From 3d4950b0f096413edb5b78b0fabda3dc81890b0f Mon Sep 17 00:00:00 2001
From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Date: Fri, 26 Apr 2024 13:24:28 +0800
Subject: [PATCH] LLM: Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example (#10876)

Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example.
---
 .../run_llama2_7b_chat_hf_arc_2_card.sh      |  3 +
 .../GPU/Deepspeed-AutoTP-FastAPI/serving.py  | 89 +++++++++++++++----
 2 files changed, 76 insertions(+), 16 deletions(-)

diff --git a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh
index baea6e65..283499f2 100644
--- a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh
+++ b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh
@@ -30,6 +30,9 @@ export USE_XETLA=OFF
 export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 export TORCH_LLM_ALLREDUCE=0
 
+export WORLD_SIZE=2
+export MAX_NUM_SEQS=16
+
 mpirun -np $NUM_GPUS --prepend-rank \
     python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
 
diff --git a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py
index 9533473d..da4cc8df 100644
--- a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py
+++ b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py
@@ -25,6 +25,12 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import uvicorn
 
+import asyncio, uuid
+from typing import Dict, List, Optional
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
     for e in env_keys:
@@ -35,6 +41,7 @@ def get_int_from_env(env_keys, default):
 
 local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
 os.environ["RANK"] = str(local_rank)
 os.environ["WORLD_SIZE"] = str(world_size)
 os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
@@ -70,7 +77,7 @@ def load_model(model_path, low_bit):
 
     model = deepspeed.init_inference(
         model,
-        mp_size=world_size,
+        tensor_parallel={"tp_size": world_size},
        dtype=torch.bfloat16,
         replace_method="auto",
     )
@@ -96,11 +103,22 @@ def load_model(model_path, low_bit):
     init_distributed()
 
     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
 
-def generate_text(prompt: str, n_predict: int = 32):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
+def generate_text(prompt: List[str], n_predict = 32):
+    while prompt[-1] == "":
+        prompt = prompt[:-1]
+    if isinstance(n_predict, list):
+        n_predict = max(n_predict)
+
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
+    # print(input_ids)
+    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
     output = model.generate(input_ids,
+                            attention_mask=attention_mask,
                             max_new_tokens=n_predict,
                             use_cache=True)
     torch.xpu.synchronize()
@@ -111,19 +129,59 @@ class PromptRequest(BaseModel):
     prompt: str
     n_predict: int = 32
 
+empty_req = PromptRequest(prompt="", n_predict=0)
+
 app = FastAPI()
 
+request_queue: asyncio.Queue = asyncio.Queue()
+result_dict: Dict[str, str] = {}
+
 @app.post("/generate/")
 async def generate(prompt_request: PromptRequest):
-    if local_rank == 0:
-        object_list = [prompt_request]
-        dist.broadcast_object_list(object_list, src=0)
-        start_time = time.time()
-        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
-        generate_time = time.time() - start_time
-        output = output.cpu()
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}
+    request_id = str(uuid.uuid4())
+    await request_queue.put((request_id, prompt_request))
+    while True:
+        await asyncio.sleep(0.1)
+        if request_id in result_dict:
+            output_str = result_dict.pop(request_id)
+            return {"generated_text": output_str}
+
+
+async def process_requests():
+    while True:
+        request_ids, prompt_requests = [], []
+        for _ in range(max_num_seqs):
+            if request_queue.empty():
+                break
+            request_id, prompt_request = await request_queue.get()
+            request_ids.append(request_id)
+            prompt_requests.append(prompt_request)
+
+        if local_rank == 0 and prompt_requests:
+            # import pdb
+            # pdb.set_trace()
+            object_list = prompt_requests
+            if len(object_list) < max_num_seqs:
+                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
+            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
+            dist.broadcast_object_list(object_list, src=0)
+            start_time = time.time()
+            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            generate_time = time.time() - start_time
+            outputs = outputs.cpu()
+            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            output_strs = output_strs[:len(prompt_requests)]
+
+            for request_id, output_str in zip(request_ids, output_strs):
+                result_dict[request_id] = output_str
+
+        await asyncio.sleep(0.1)
+
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(process_requests())
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
@@ -143,7 +201,6 @@ if __name__ == "__main__":
         uvicorn.run(app, host="0.0.0.0", port=args.port)
     else:
         while True:
-            object_list = [None]
+            object_list = [None] * max_num_seqs
             dist.broadcast_object_list(object_list, src=0)
-            output = generate_text(object_list[0].prompt, object_list[0].n_predict)
-
+            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
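
Reviewer note (not part of the patch): below is a minimal client sketch that exercises the new batching path by sending several concurrent POST requests to the /generate/ endpoint. The endpoint path, the payload fields (prompt, n_predict) and the "generated_text" response key come from serving.py above; the localhost URL, the helper name, and the sample prompts are illustrative assumptions.

# Hypothetical client sketch: fire several concurrent requests so that
# process_requests() on rank 0 can group them into one padded batch of
# up to MAX_NUM_SEQS. Assumes serving.py is running locally on port 8000.
import json
import urllib.request
from concurrent.futures import ThreadPoolExecutor

URL = "http://localhost:8000/generate/"  # assumed host/port from the example's default --port

def generate(prompt: str, n_predict: int = 32) -> str:
    # Payload fields match the PromptRequest model in serving.py.
    payload = json.dumps({"prompt": prompt, "n_predict": n_predict}).encode("utf-8")
    req = urllib.request.Request(URL, data=payload,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        # The batched endpoint returns only the generated text.
        return json.loads(resp.read())["generated_text"]

if __name__ == "__main__":
    prompts = [f"Tell me a fact about the number {i}." for i in range(8)]
    # Concurrent requests land in the server's request_queue and are
    # broadcast to all ranks together instead of one prompt at a time.
    with ThreadPoolExecutor(max_workers=8) as pool:
        for text in pool.map(generate, prompts):
            print(text)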