LLM: Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example (#10876)

Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example.
Xiangyu Tian 2024-04-26 13:24:28 +08:00 committed by GitHub
parent 3e8ed54270
commit 3d4950b0f0
2 changed files with 76 additions and 16 deletions
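
In outline, the change routes each /generate/ request through an asyncio queue: rank 0 drains up to MAX_NUM_SEQS pending requests per scheduling step, pads the batch with empty placeholder requests so every rank sees a list of fixed length, broadcasts that list, and each rank then runs a single batched generate call. Below is a minimal, self-contained sketch of just that padding-and-truncation bookkeeping; Req and pad_batch are hypothetical names used for illustration, not code from the example.

from dataclasses import dataclass
from typing import List

@dataclass
class Req:
    prompt: str
    n_predict: int = 32

MAX_NUM_SEQS = 16                        # mirrors the new MAX_NUM_SEQS env var
empty_req = Req(prompt="", n_predict=0)

def pad_batch(reqs: List[Req]) -> List[Req]:
    # Pad a partial batch to a fixed length so dist.broadcast_object_list
    # always transfers the same number of objects on every rank.
    return reqs + [empty_req] * (MAX_NUM_SEQS - len(reqs))

pending = [Req("What is AI?"), Req("Tell me a joke.", 48)]
batch = pad_batch(pending)
assert len(batch) == MAX_NUM_SEQS
# After generation, only the outputs for the real requests are kept:
outputs = [f"<generated for {r.prompt!r}>" for r in batch][:len(pending)]
print(outputs)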

@@ -30,6 +30,9 @@ export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0
export WORLD_SIZE=2
export MAX_NUM_SEQS=16
mpirun -np $NUM_GPUS --prepend-rank \
python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
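
Once the server is running, clients call the same /generate/ endpoint as before, and concurrent requests are batched on the server side. A hypothetical client call, assuming the server listens on localhost at the --port value given above and that the requests package is installed:

import requests

resp = requests.post(
    "http://localhost:8000/generate/",
    json={"prompt": "What is AI?", "n_predict": 32},  # fields of PromptRequest
)
resp.raise_for_status()
print(resp.json()["generated_text"])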

@@ -25,6 +25,12 @@ from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import asyncio, uuid
from typing import Dict, List, Optional
from transformers.utils import logging
logger = logging.get_logger(__name__)

def get_int_from_env(env_keys, default):
    """Returns the first positive env value found in the `env_keys` list or the default."""
    for e in env_keys:
@@ -35,6 +41,7 @@ def get_int_from_env(env_keys, default):
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
@@ -70,7 +77,7 @@ def load_model(model_path, low_bit):
    model = deepspeed.init_inference(
        model,
        mp_size=world_size,
        tensor_parallel={"tp_size": world_size},
        dtype=torch.bfloat16,
        replace_method="auto",
    )
@@ -96,11 +103,22 @@ def load_model(model_path, low_bit):
    init_distributed()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

def generate_text(prompt: str, n_predict: int = 32):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')

def generate_text(prompt: List[str], n_predict = 32):
    while prompt[-1] == "":
        prompt = prompt[:-1]
    if isinstance(n_predict, list):
        n_predict = max(n_predict)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
    # print(input_ids)
    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
    output = model.generate(input_ids,
                            attention_mask=attention_mask,
                            max_new_tokens=n_predict,
                            use_cache=True)
    torch.xpu.synchronize()
@@ -111,19 +129,59 @@ class PromptRequest(BaseModel):
    prompt: str
    n_predict: int = 32

empty_req = PromptRequest(prompt="", n_predict=0)

app = FastAPI()
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}

@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    if local_rank == 0:
        object_list = [prompt_request]
        dist.broadcast_object_list(object_list, src=0)
        start_time = time.time()
        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
        generate_time = time.time() - start_time
        output = output.cpu()
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}
    request_id = str(uuid.uuid4())
    await request_queue.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0.1)
        if request_id in result_dict:
            output_str = result_dict.pop(request_id)
            return {"generated_text": output_str}

async def process_requests():
    while True:
        request_ids, prompt_requests = [], []
        for _ in range(max_num_seqs):
            if request_queue.empty():
                break
            request_id, prompt_request = await request_queue.get()
            request_ids.append(request_id)
            prompt_requests.append(prompt_request)

        if local_rank == 0 and prompt_requests:
            # import pdb
            # pdb.set_trace()
            object_list = prompt_requests
            if len(object_list) < max_num_seqs:
                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))

            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
            dist.broadcast_object_list(object_list, src=0)
            start_time = time.time()
            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
            generate_time = time.time() - start_time
            outputs = outputs.cpu()
            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            output_strs = output_strs[:len(prompt_requests)]

            for request_id, output_str in zip(request_ids, output_strs):
                result_dict[request_id] = output_str

        await asyncio.sleep(0.1)

@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests())

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
@@ -143,7 +201,6 @@ if __name__ == "__main__":
        uvicorn.run(app, host="0.0.0.0", port=args.port)
    else:
        while True:
            object_list = [None]
            object_list = [None] * max_num_seqs
            dist.broadcast_object_list(object_list, src=0)
            output = generate_text(object_list[0].prompt, object_list[0].n_predict)
            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
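
For readers unfamiliar with dist.broadcast_object_list, the following self-contained toy (hypothetical, CPU-only with the gloo backend, not part of the example) reproduces the coordination pattern used above: rank 0 builds and pads the batch, the other rank supplies placeholders, and both end up with the identical fixed-size batch so their subsequent batched generate calls stay in lockstep.

import torch.distributed as dist
import torch.multiprocessing as mp

MAX_NUM_SEQS = 4

def worker(rank, world_size=2):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    if rank == 0:
        prompts = ["What is AI?", "Tell me a joke."]
        object_list = prompts + [""] * (MAX_NUM_SEQS - len(prompts))  # pad batch
    else:
        object_list = [None] * MAX_NUM_SEQS                           # placeholders
    dist.broadcast_object_list(object_list, src=0)
    # Every rank now holds the same batch and would run the same generate call.
    print(f"rank {rank}: {object_list}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)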