When the client uses 2 (or more) threads, the LLaDA2.0-mini server running on 1×A100-40G returns more and more wrong answers on HumanEval.
But when the client uses 1 thread, the answers are fine.
Do you know the reason? @zheng-da @lundu28
My dInfer code is at master commit a8b4a06.
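For reference, the client side is roughly equivalent to the sketch below (assumed; the actual HumanEval harness is not shown here). It posts prompts to the /chat endpoint from a small thread pool, using the request fields expected by the server code further down:

```python
# Minimal sketch of the multi-threaded client (assumed, not the exact script).
# Each worker thread posts one prompt at a time to the /chat endpoint.
import requests
from concurrent.futures import ThreadPoolExecutor

URL = 'http://localhost:8000/chat'  # assumed host/port


def ask(prompt: str) -> str:
    payload = {
        'user_id': 'humaneval',
        'stream': False,
        'messages': [{'role': 'user', 'content': prompt}],
    }
    return requests.post(URL, json=payload, timeout=600).text


prompts = ['...HumanEval prompt...'] * 8  # placeholder prompts

# max_workers=1 -> answers are correct; max_workers>=2 -> errors accumulate
with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(ask, prompts))
```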
The server code is as follows:
app = FastAPI(
    title='DLLM Server',
    redoc_url=None,
    docs=None,
)
def get_dllm_model(world_size, rank, gpu_id, device, args):
    torch.cuda.set_device(gpu_id)
    tokenizer = AutoTokenizer.from_pretrained(SPECIAL_MODEL_DIR,
                                              trust_remote_code=True)
    block_length = args.block_length
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = args.port
    distributed.init_distributed_environment(world_size, rank, 'env://', rank,
                                             'nccl')
    distributed.initialize_model_parallel(args.tp_size,
                                          args.ep_size,
                                          1,
                                          backend='nccl')
    print("[Loading model]")
    model_config = AutoConfig.from_pretrained(SPECIAL_MODEL_DIR,
                                              trust_remote_code=True)
    if args.use_quant:
        load_quant_config(model_config, SPECIAL_MODEL_DIR)
        server_args = ServerArgs(model_path=SPECIAL_MODEL_DIR,
                                 quantization="modelopt_fp8",
                                 modelopt_quant="fp8",
                                 enable_dp_attention=True,
                                 trust_remote_code=True,
                                 tp_size=args.tp_size,
                                 dp_size=1,
                                 pp_size=1)
    else:
        server_args = ServerArgs(model_path=SPECIAL_MODEL_DIR,
                                 enable_dp_attention=True,
                                 trust_remote_code=True,
                                 tp_size=args.tp_size,
                                 dp_size=1,
                                 pp_size=1)
    try:
        from sglang.srt.server_args import set_global_server_args_for_scheduler
    except ImportError:
        pass
    else:
        set_global_server_args_for_scheduler(server_args)
    initialize_dp_attention(
        server_args=server_args,
        model_config=model_config,
    )
    initialize_moe_config(server_args)
    if args.use_quant:
        model = LLaDA2SGLangLM(config=model_config,
                               quant_config=model_config.quant_config,
                               expert_map_path='.').eval()
    else:
        model = LLaDA2SGLangLM(config=model_config, expert_map_path='.').eval()
    torch.set_default_dtype(torch.bfloat16)
    model.load_weights(SPECIAL_MODEL_DIR, device=device)
    initialize_moe_config(server_args)
    model = model.to(device)
    model.after_processing(
    )  # if model is quantized, use quant_method.process_weights_after_loading
    max_length = int(os.environ.get('TASK_DLLM_MAX_LENGTH', '4096'))
    model = ModelRunner(model,
                        device,
                        server_args=server_args,
                        max_length=max_length,
                        enable_cuda_graph=True,
                        supported_batch_sizes=[args.batch_size])
    if args.parallel_decoding == 'threshold':
        if args.use_credit:
            decoder = CreditThresholdParallelDecoder(temperature=0,
                                                     threshold=args.threshold,
                                                     mask_id=156895,
                                                     eos_id=156892)
        else:
            decoder = ThresholdParallelDecoder(temperature=0,
                                               threshold=args.threshold,
                                               mask_id=156895,
                                               eos_id=156892)
    else:
        decoder = HierarchyDecoder(temperature=0,
                                   threshold=args.threshold,
                                   low_threshold=args.low_threshold,
                                   mask_id=156895,
                                   eos_id=156892)
    use_sw = args.prefix_look > 0 or args.after_look > 0 or args.warmup_times > 0
    if args.cache == 'prefix' or args.cache == 'dual':
        cache_factory = KVCacheFactory(args.cache,
                                       is_bd_model=args.use_bd,
                                       backend='sglang',
                                       max_length=max_length)
    else:
        cache_factory = None
    if not args.use_bd:
        if args.cont_weight > 0:
            if use_sw:
                dllm = IterSmoothWithVicinityCacheDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    cont_weight=args.cont_weight,
                    prefix_look=args.prefix_look,
                    after_look=args.after_look,
                    warmup_steps=args.warmup_times)
            else:
                dllm = IterSmoothDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    cont_weight=args.cont_weight)
        else:
            if use_sw:
                dllm = VicinityCacheDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    prefix_look=args.prefix_look,
                    after_look=args.after_look,
                    warmup_steps=args.warmup_times)
            else:
                dllm = BlockWiseDiffusionLLM(
                    model,
                    decoder,
                    BlockIteratorFactory(start_block_align=True),
                    cache_factory=cache_factory,
                    early_stop=True,
                    use_shift=args.use_shift)
    else:
        dllm = BlockDiffusionLLM(model,
                                 decoder,
                                 BlockIteratorFactory(
                                     start_block_align=True,
                                     use_block_diffusion=True),
                                 cache_factory=cache_factory,
                                 early_stop=True,
                                 maximum_unroll=1,
                                 expected_tpf=15,
                                 backend='sglang')
    # warmup for decoding algorithms
    input_ids = torch.arange(64, dtype=torch.long, device=device).unsqueeze(0)
    dllm.generate(input_ids,
                  gen_length=args.gen_len,
                  block_length=args.block_length)
    return tokenizer, dllm
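

# The model is built once at import time; tokenizer and MODEL_DLLM below are
# module-level globals shared by every request handler thread.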
args = get_args()
MODEL_DEVICE = torch.device(TASK_DINFER_GPU_ID)
tokenizer, MODEL_DLLM = get_dllm_model(TASK_DINFER_WORLD_SIZE,
                                       TASK_DINFER_RANK, TASK_DINFER_GPU_ID,
                                       MODEL_DEVICE, args)


def get_answer_no_stream(chat_uuid: str, data: Dict) -> str:
    input_ids = tokenizer.apply_chat_template(
        data['messages'],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors='pt',
    ).to(MODEL_DEVICE)
    x_tokens_final = MODEL_DLLM.generate(input_ids,
                                         gen_length=TASK_DLLM_GEN_LENGTH,
                                         block_length=TASK_DLLM_BLOCK_LENGTH)
    resp = {}
    # x_str = tokenizer.decode(x_tokens_final[0])
    x_str = tokenizer.decode(x_tokens_final)  # for dInfer v0.2.0
    text = x_str.split('<role>ASSISTANT</role>')[-1]
    text = text.replace('<|endoftext|>', '').replace('<|role_end|>', '')
    text = text.replace('<|mask|>', ' ')
    resp = construct_response(STATUS_OK, '', 'dllm', [{
        'text': text
    }], [])
    logging.info('[%s] resp: %s', chat_uuid, resp)
    return 'data: ' + json.dumps(resp) + '\n'
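

# chat() below is a plain def endpoint, so FastAPI/Starlette runs it in a
# worker thread pool: two concurrent /chat requests can reach
# MODEL_DLLM.generate() at the same time.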
@app.post('/chat')
def chat(request: CompletionsRequest):
    ''' chat
    '''
    chat_uuid = f'chat-{str(uuid.uuid4())}'
    data = {
        'user_id': request.user_id,
        'stream': request.stream,
        'messages': request.messages,
    }
    logging.info('[%s] req: %s', chat_uuid, data)
    if request.stream:
        return StreamingResponse(get_answer(chat_uuid, data))
    return Response(
        content=get_answer_no_stream(chat_uuid, data),
        media_type='text/plain',
    )
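
Because chat is a synchronous endpoint running in FastAPI's thread pool, two client threads end up calling MODEL_DLLM.generate() concurrently on the same model instance (same CUDA graphs and caches). As a diagnostic sketch only (an assumed change, not a confirmed fix): guarding the shared model with a lock should make the 2-thread client behave like the 1-thread client; if that restores HumanEval accuracy, the problem is concurrent access to shared generation state.

```python
# Sketch (not part of the server above): serialize access to the shared
# MODEL_DLLM so that only one request runs generate() at a time.
import threading

_GENERATE_LOCK = threading.Lock()


def generate_serialized(input_ids):
    # Hold the lock for the whole generate() call; concurrent /chat requests
    # then wait instead of touching the same model state simultaneously.
    with _GENERATE_LOCK:
        return MODEL_DLLM.generate(input_ids,
                                   gen_length=TASK_DLLM_GEN_LENGTH,
                                   block_length=TASK_DLLM_BLOCK_LENGTH)

# get_answer_no_stream() would call generate_serialized(input_ids) instead of
# MODEL_DLLM.generate(...).
```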