
Multi-threaded curl requests to the LLaDA2.0-mini dInfer server produce more and more wrong answers on HumanEval #31


Description

@AIxyz

When the client uses 2 (or more) threads, the LLaDA2.0-mini server running on 1×A100-40G produces more and more wrong answers on HumanEval.

[Screenshot: HumanEval results]

But when the client uses only 1 thread, the answers are correct.
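For reference, here is a minimal Python equivalent of the threaded curl calls (the /chat payload fields follow the server code below; the host/port and the HumanEval prompt loading are placeholders):

import json
from concurrent.futures import ThreadPoolExecutor

import requests

URL = 'http://localhost:8000/chat'  # placeholder address of the dInfer server


def ask(prompt):
    payload = {
        'user_id': 'humaneval',
        'stream': False,
        'messages': [{'role': 'user', 'content': prompt}],
    }
    resp = requests.post(URL, json=payload, timeout=600)
    # the server returns 'data: {...}\n'; strip the prefix before parsing
    return json.loads(resp.text[len('data: '):])


prompts = ['...']  # HumanEval prompts go here

# max_workers=2 reproduces the problem; max_workers=1 gives correct answers
with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(ask, prompts))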

Do you know the reason? @zheng-da @lundu28

My dInfer code is at master commit a8b4a06.

The server code is as follows:

app = FastAPI(
    title='DLLM Server',
    redoc_url=None,
    docs_url=None,
)


def get_dllm_model(world_size, rank, gpu_id, device, args):
    torch.cuda.set_device(gpu_id)

    tokenizer = AutoTokenizer.from_pretrained(SPECIAL_MODEL_DIR,
                                              trust_remote_code=True)
    block_length = args.block_length

    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(args.port)
    distributed.init_distributed_environment(world_size, rank, 'env://', rank,
                                             'nccl')
    distributed.initialize_model_parallel(args.tp_size,
                                          args.ep_size,
                                          1,
                                          backend='nccl')
    print("[Loading model]")

    model_config = AutoConfig.from_pretrained(SPECIAL_MODEL_DIR,
                                              trust_remote_code=True)
    if args.use_quant:
        load_quant_config(model_config, SPECIAL_MODEL_DIR)
        server_args = ServerArgs(model_path=SPECIAL_MODEL_DIR,
                                 quantization="modelopt_fp8",
                                 modelopt_quant="fp8",
                                 enable_dp_attention=True,
                                 trust_remote_code=True,
                                 tp_size=args.tp_size,
                                 dp_size=1,
                                 pp_size=1)
    else:
        server_args = ServerArgs(model_path=SPECIAL_MODEL_DIR,
                                 enable_dp_attention=True,
                                 trust_remote_code=True,
                                 tp_size=args.tp_size,
                                 dp_size=1,
                                 pp_size=1)
    try:
        from sglang.srt.server_args import set_global_server_args_for_scheduler
    except ImportError:
        pass
    else:
        set_global_server_args_for_scheduler(server_args)
    initialize_dp_attention(
        server_args=server_args,
        model_config=model_config,
    )
    initialize_moe_config(server_args)
    if args.use_quant:
        model = LLaDA2SGLangLM(config=model_config,
                               quant_config=model_config.quant_config,
                               expert_map_path='.').eval()
    else:
        model = LLaDA2SGLangLM(config=model_config, expert_map_path='.').eval()
    torch.set_default_dtype(torch.bfloat16)
    model.load_weights(SPECIAL_MODEL_DIR, device=device)
    initialize_moe_config(server_args)

    model = model.to(device)
    # if the model is quantized, this runs quant_method.process_weights_after_loading
    model.after_processing()
    max_length = int(os.environ.get('TASK_DLLM_MAX_LENGTH', '4096'))
    model = ModelRunner(model,
                        device,
                        server_args=server_args,
                        max_length=max_length,
                        enable_cuda_graph=True,
                        supported_batch_sizes=[args.batch_size])

    if args.parallel_decoding == 'threshold':
        if args.use_credit:
            decoder = CreditThresholdParallelDecoder(temperature=0,
                                                     threshold=args.threshold,
                                                     mask_id=156895,
                                                     eos_id=156892)
        else:
            decoder = ThresholdParallelDecoder(temperature=0,
                                               threshold=args.threshold,
                                               mask_id=156895,
                                               eos_id=156892)

    else:
        decoder = HierarchyDecoder(temperature=0,
                                   threshold=args.threshold,
                                   low_threshold=args.low_threshold,
                                   mask_id=156895,
                                   eos_id=156892)
    use_sw = args.prefix_look > 0 or args.after_look > 0 or args.warmup_times > 0
    if args.cache == 'prefix' or args.cache == 'dual':
        cache_factory = KVCacheFactory(args.cache,
                                       is_bd_model=args.use_bd,
                                       backend='sglang',
                                       max_length=max_length)
    else:
        cache_factory = None

    if not args.use_bd:
        if args.cont_weight > 0:
            if use_sw:
                dllm = IterSmoothWithVicinityCacheDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    cont_weight=args.cont_weight,
                    prefix_look=args.prefix_look,
                    after_look=args.after_look,
                    warmup_steps=args.warmup_times)
            else:
                dllm = IterSmoothDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    cont_weight=args.cont_weight)
        else:
            if use_sw:
                dllm = VicinityCacheDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    prefix_look=args.prefix_look,
                    after_look=args.after_look,
                    warmup_steps=args.warmup_times)
            else:
                dllm = BlockWiseDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    use_shift=args.use_shift)
    else:
        dllm = BlockDiffusionLLM(model,
                                 decoder,
                                 BlockIteratorFactory(
                                     start_block_align=True,
                                     use_block_diffusion=True),
                                 cache_factory=cache_factory,
                                 early_stop=True,
                                 maximum_unroll=1,
                                 expected_tpf=15,
                                 backend='sglang')
    # warmup for decoding algorithms
    input_ids = torch.arange(64, dtype=torch.long, device=device).unsqueeze(0)
    dllm.generate(input_ids,
                  gen_length=args.gen_len,
                  block_length=args.block_length)

    return tokenizer, dllm


args = get_args()
MODEL_DEVICE = torch.device(TASK_DINFER_GPU_ID)
tokenizer, MODEL_DLLM = get_dllm_model(TASK_DINFER_WORLD_SIZE,
                                       TASK_DINFER_RANK, TASK_DINFER_GPU_ID,
                                       MODEL_DEVICE, args)
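# Note: MODEL_DLLM is a single module-level instance; every request handler
# thread shares this same dllm object (including its cache and CUDA-graph state).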


def get_answer_no_stream(chat_uuid: str, data: Dict) -> str:
    input_ids = tokenizer.apply_chat_template(
        data['messages'],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors='pt',
    ).to(MODEL_DEVICE)

    x_tokens_final = MODEL_DLLM.generate(input_ids,
                                         gen_length=TASK_DLLM_GEN_LENGTH,
                                         block_length=TASK_DLLM_BLOCK_LENGTH)

    resp = {}

    # x_str = tokenizer.decode(x_tokens_final[0])
    x_str = tokenizer.decode(x_tokens_final)  # for dInfer v0.2.0
    text = x_str.split('<role>ASSISTANT</role>')[-1]
    text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '')
    text = text.replace('<|mask|>', ' ')
    resp = construct_response(STATUS_OK, '', 'dllm', [{
        'text': text
    }], [])

    logging.info('[%s] resp: %s', chat_uuid, resp)
    return 'data: ' + json.dumps(resp) + '\n'


@app.post('/chat')
def chat(request: CompletionsRequest):
    ''' chat
    '''
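    # This is a sync 'def' endpoint, so FastAPI runs it in a threadpool; with a
    # multi-threaded client, several requests can reach get_answer_no_stream()
    # and MODEL_DLLM.generate() at the same time.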
    chat_uuid = f'chat-{str(uuid.uuid4())}'
    data = {
        'user_id': request.user_id,
        'stream': request.stream,
        'messages': request.messages,
    }
    logging.info('[%s] req: %s', chat_uuid, data)
    if request.stream:
        return StreamingResponse(get_answer(chat_uuid, data))
    return Response(
        content=get_answer_no_stream(chat_uuid, data),
        media_type='text/plain',
    )
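
If it helps with the diagnosis: because /chat is a sync endpoint and MODEL_DLLM is shared, two in-flight requests can call generate() concurrently on the same object. Assuming the generate path is not thread-safe (just a guess at the cause, not confirmed), a minimal sketch of serializing generation with a lock looks like this:

import threading

# Assumption: MODEL_DLLM.generate is not safe to call from two threads at once.
# A module-level lock serializes generation while HTTP handling stays threaded.
GENERATE_LOCK = threading.Lock()


def generate_locked(input_ids, gen_length, block_length):
    with GENERATE_LOCK:
        return MODEL_DLLM.generate(input_ids,
                                    gen_length=gen_length,
                                    block_length=block_length)

If the two-thread client gives the same answers as the single-thread one once generation is serialized this way, the concurrent generate() calls are the culprit; if not, the problem is elsewhere.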
