LLM: Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example (#10876)

Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example.
Xiangyu Tian 2024-04-26 13:24:28 +08:00 committed by GitHub
parent 3e8ed54270
commit 3d4950b0f0
2 changed files with 76 additions and 16 deletions
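Summary of the change: the /generate/ endpoint no longer runs generation inline. Each incoming request is given a UUID and put on an asyncio queue; a background process_requests() task, started on FastAPI startup, drains up to MAX_NUM_SEQS requests per step, pads the batch with empty requests, broadcasts it from rank 0 to all ranks, runs one batched generate_text() call, and writes the decoded outputs into a result dict that the waiting handlers poll. The run command gains WORLD_SIZE and MAX_NUM_SEQS environment variables, and deepspeed.init_inference() switches to the tensor_parallel config.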

File 1 of 2 — the example's run command:

@@ -30,6 +30,9 @@ export USE_XETLA=OFF
 export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 export TORCH_LLM_ALLREDUCE=0
+export WORLD_SIZE=2
+export MAX_NUM_SEQS=16
 mpirun -np $NUM_GPUS --prepend-rank \
   python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
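WORLD_SIZE should match the number of ranks launched by mpirun (the tensor-parallel degree), and MAX_NUM_SEQS, read by serving.py with a default of 16, caps how many queued requests are batched into a single generation step.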

File 2 of 2 — serving.py:

@@ -25,6 +25,12 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import uvicorn
+import asyncio, uuid
+from typing import Dict, List, Optional
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+

 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
     for e in env_keys:
@@ -35,6 +41,7 @@ def get_int_from_env(env_keys, default):
 local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
 os.environ["RANK"] = str(local_rank)
 os.environ["WORLD_SIZE"] = str(world_size)
 os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
@@ -70,7 +77,7 @@ def load_model(model_path, low_bit):
     model = deepspeed.init_inference(
         model,
-        mp_size=world_size,
+        tensor_parallel={"tp_size": world_size},
         dtype=torch.bfloat16,
         replace_method="auto",
     )
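This argument change appears to track newer DeepSpeed releases, in which the mp_size argument of init_inference is deprecated in favor of a tensor_parallel config dict; the tensor-parallel degree itself is unchanged (world_size).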
@@ -96,11 +103,22 @@ def load_model(model_path, low_bit):
     init_distributed()

     # Load tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token

-def generate_text(prompt: str, n_predict: int = 32):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
+def generate_text(prompt: List[str], n_predict = 32):
+    while prompt[-1] == "":
+        prompt = prompt[:-1]
+    if isinstance(n_predict, list):
+        n_predict = max(n_predict)
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
+    # print(input_ids)
+    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
     output = model.generate(input_ids,
+                            attention_mask=attention_mask,
                             max_new_tokens=n_predict,
                             use_cache=True)
     torch.xpu.synchronize()
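Left padding plus an explicit pad token are what make batched generation work for a decoder-only model: with right padding, generation would have to continue from pad tokens. A minimal sketch of the same tokenizer pattern, separate from this example (the gpt2 checkpoint is purely illustrative; serving.py uses the model passed via --repo-id-or-model-path):

from transformers import AutoTokenizer

# Illustrative checkpoint only.
tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
if tok.pad_token is None:          # gpt2 defines eos but no pad token
    tok.pad_token = tok.eos_token

batch = tok(["short prompt", "a noticeably longer prompt"],
            return_tensors="pt", padding=True)
# With left padding, pad tokens sit at the start of each row, so the last
# position of every row is real text for model.generate() to continue from;
# attention_mask marks the padded positions so they are ignored.
print(batch.input_ids.shape, batch.attention_mask[0])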
@@ -111,19 +129,59 @@ class PromptRequest(BaseModel):
     prompt: str
     n_predict: int = 32

+empty_req = PromptRequest(prompt="", n_predict=0)
+
 app = FastAPI()

+request_queue: asyncio.Queue = asyncio.Queue()
+result_dict: Dict[str, str] = {}
+

 @app.post("/generate/")
 async def generate(prompt_request: PromptRequest):
-    if local_rank == 0:
-        object_list = [prompt_request]
-        dist.broadcast_object_list(object_list, src=0)
-        start_time = time.time()
-        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
-        generate_time = time.time() - start_time
-        output = output.cpu()
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}
+    request_id = str(uuid.uuid4())
+    await request_queue.put((request_id, prompt_request))
+    while True:
+        await asyncio.sleep(0.1)
+        if request_id in result_dict:
+            output_str = result_dict.pop(request_id)
+            return {"generated_text": output_str}
+
+
+async def process_requests():
+    while True:
+        request_ids, prompt_requests = [], []
+        for _ in range(max_num_seqs):
+            if request_queue.empty():
+                break
+            request_id, prompt_request = await request_queue.get()
+            request_ids.append(request_id)
+            prompt_requests.append(prompt_request)
+
+        if local_rank == 0 and prompt_requests:
+            # import pdb
+            # pdb.set_trace()
+            object_list = prompt_requests
+            if len(object_list) < max_num_seqs:
+                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
+            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
+            dist.broadcast_object_list(object_list, src=0)
+            start_time = time.time()
+            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            generate_time = time.time() - start_time
+            outputs = outputs.cpu()
+            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            output_strs = output_strs[:len(prompt_requests)]
+            for request_id, output_str in zip(request_ids, output_strs):
+                result_dict[request_id] = output_str
+
+        await asyncio.sleep(0.1)
+
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(process_requests())
+

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
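A hypothetical client for exercising the new batching path (host, port, prompts, and the thread pool are illustrative; the /generate/ route and the prompt/n_predict fields come from PromptRequest above). Sending several requests concurrently lets process_requests() drain them as one batch:

import concurrent.futures

import requests  # assumption: the requests package is available on the client side

def ask(prompt: str) -> str:
    resp = requests.post("http://localhost:8000/generate/",
                         json={"prompt": prompt, "n_predict": 32})
    resp.raise_for_status()
    return resp.json()["generated_text"]

prompts = [f"Question {i}: what is tensor parallelism?" for i in range(8)]
# Concurrent requests land in the server-side asyncio queue and are picked up
# together (up to MAX_NUM_SEQS) by process_requests().
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
    for text in pool.map(ask, prompts):
        print(text)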
@@ -143,7 +201,6 @@ if __name__ == "__main__":
         uvicorn.run(app, host="0.0.0.0", port=args.port)
     else:
         while True:
-            object_list = [None]
+            object_list = [None] * max_num_seqs
             dist.broadcast_object_list(object_list, src=0)
-            output = generate_text(object_list[0].prompt, object_list[0].n_predict)
+            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
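With the padded batch, ranks other than 0 stay in lockstep with rank 0: each loop iteration blocks in dist.broadcast_object_list() until rank 0 sends the max_num_seqs-sized list, then runs the same batched generate_text() call so the tensor-parallel collectives line up; only rank 0 decodes the outputs and returns them to clients.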