LLM: Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example (#10876)
Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example.
This commit is contained in:
parent
3e8ed54270
commit
3d4950b0f0
2 changed files with 76 additions and 16 deletions
|
|
@ -30,6 +30,9 @@ export USE_XETLA=OFF
|
|||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
|
||||
export TORCH_LLM_ALLREDUCE=0
|
||||
|
||||
export WORLD_SIZE=2
|
||||
export MAX_NUM_SEQS=16
|
||||
|
||||
mpirun -np $NUM_GPUS --prepend-rank \
|
||||
python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,12 @@ from fastapi import FastAPI, HTTPException
|
|||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
import asyncio, uuid
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from transformers.utils import logging
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
def get_int_from_env(env_keys, default):
|
||||
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
||||
for e in env_keys:
|
||||
|
|
@ -35,6 +41,7 @@ def get_int_from_env(env_keys, default):
|
|||
|
||||
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
|
||||
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
|
||||
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
|
||||
os.environ["RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
|
||||
|
|
@ -70,7 +77,7 @@ def load_model(model_path, low_bit):
|
|||
|
||||
model = deepspeed.init_inference(
|
||||
model,
|
||||
mp_size=world_size,
|
||||
tensor_parallel={"tp_size": world_size},
|
||||
dtype=torch.bfloat16,
|
||||
replace_method="auto",
|
||||
)
|
||||
|
|
@ -96,11 +103,22 @@ def load_model(model_path, low_bit):
|
|||
init_distributed()
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def generate_text(prompt: str, n_predict: int = 32):
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
|
||||
def generate_text(prompt: List[str], n_predict = 32):
|
||||
while prompt[-1] == "":
|
||||
prompt = prompt[:-1]
|
||||
if isinstance(n_predict, list):
|
||||
n_predict = max(n_predict)
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt", padding=True)
|
||||
input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
|
||||
# print(input_ids)
|
||||
attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
|
||||
output = model.generate(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_new_tokens=n_predict,
|
||||
use_cache=True)
|
||||
torch.xpu.synchronize()
|
||||
|
|
@ -111,19 +129,59 @@ class PromptRequest(BaseModel):
|
|||
prompt: str
|
||||
n_predict: int = 32
|
||||
|
||||
empty_req = PromptRequest(prompt="", n_predict=0)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
request_queue: asyncio.Queue = asyncio.Queue()
|
||||
result_dict: Dict[str, str] = {}
|
||||
|
||||
@app.post("/generate/")
|
||||
async def generate(prompt_request: PromptRequest):
|
||||
if local_rank == 0:
|
||||
object_list = [prompt_request]
|
||||
dist.broadcast_object_list(object_list, src=0)
|
||||
start_time = time.time()
|
||||
output = generate_text(object_list[0].prompt, object_list[0].n_predict)
|
||||
generate_time = time.time() - start_time
|
||||
output = output.cpu()
|
||||
output_str = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}
|
||||
request_id = str(uuid.uuid4())
|
||||
await request_queue.put((request_id, prompt_request))
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
if request_id in result_dict:
|
||||
output_str = result_dict.pop(request_id)
|
||||
return {"generated_text": output_str}
|
||||
|
||||
|
||||
async def process_requests():
|
||||
while True:
|
||||
request_ids, prompt_requests = [], []
|
||||
for _ in range(max_num_seqs):
|
||||
if request_queue.empty():
|
||||
break
|
||||
request_id, prompt_request = await request_queue.get()
|
||||
request_ids.append(request_id)
|
||||
prompt_requests.append(prompt_request)
|
||||
|
||||
if local_rank == 0 and prompt_requests:
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
object_list = prompt_requests
|
||||
if len(object_list) < max_num_seqs:
|
||||
object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
|
||||
logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
|
||||
dist.broadcast_object_list(object_list, src=0)
|
||||
start_time = time.time()
|
||||
outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
|
||||
generate_time = time.time() - start_time
|
||||
outputs = outputs.cpu()
|
||||
output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
output_strs = output_strs[:len(prompt_requests)]
|
||||
|
||||
for request_id, output_str in zip(request_ids, output_strs):
|
||||
result_dict[request_id] = output_str
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
asyncio.create_task(process_requests())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
|
||||
|
|
@ -143,7 +201,6 @@ if __name__ == "__main__":
|
|||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
else:
|
||||
while True:
|
||||
object_list = [None]
|
||||
object_list = [None] * max_num_seqs
|
||||
dist.broadcast_object_list(object_list, src=0)
|
||||
output = generate_text(object_list[0].prompt, object_list[0].n_predict)
|
||||
|
||||
output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
|
||||
|
|
|
|||
Loading…
Reference in a new issue