LLM: Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example (#10876)
Parent: 3e8ed54270
Commit: 3d4950b0f0
2 changed files with 76 additions and 16 deletions
Run script:

@@ -30,6 +30,9 @@ export USE_XETLA=OFF
 export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 export TORCH_LLM_ALLREDUCE=0
 
+export WORLD_SIZE=2
+export MAX_NUM_SEQS=16
+
 mpirun -np $NUM_GPUS --prepend-rank \
         python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
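The two new exports feed the batching path added to serving.py below: WORLD_SIZE sets how many ranks mpirun launches for the tensor-parallel group, and MAX_NUM_SEQS caps how many queued requests are fused into one generate() call per step. A small sketch of the resulting padding arithmetic (illustrative only; the queue length is a made-up value, not part of the example):

import os

# The run script exports MAX_NUM_SEQS=16; default in serving.py is also 16.
os.environ.setdefault("MAX_NUM_SEQS", "16")

max_num_seqs = int(os.environ["MAX_NUM_SEQS"])
queued = 5  # hypothetical number of requests currently waiting
padding = max_num_seqs - queued  # rank 0 tops the batch up with empty requests
print(f"batch = {queued} real + {padding} empty requests")
# -> batch = 5 real + 11 empty requests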
serving.py:

@@ -25,6 +25,12 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import uvicorn
 
+import asyncio, uuid
+from typing import Dict, List, Optional
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
     for e in env_keys:

@@ -35,6 +41,7 @@ def get_int_from_env(env_keys, default):
 
 local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
+max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
 os.environ["RANK"] = str(local_rank)
 os.environ["WORLD_SIZE"] = str(world_size)
 os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
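The body of get_int_from_env is truncated in this diff view. For reference, a minimal sketch of the conventional implementation of such a helper (the exact body here is an assumption based on the visible docstring and loop header):

import os

def get_int_from_env(env_keys, default):
    """Returns the first positive env value found in the `env_keys` list or the default."""
    # Sketch only: the loop body is cut off in the diff; this is the usual form.
    for e in env_keys:
        val = int(os.environ.get(e, -1))
        if val >= 0:
            return val
    return int(default)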
@@ -70,7 +77,7 @@ def load_model(model_path, low_bit):
 
     model = deepspeed.init_inference(
         model,
-        mp_size=world_size,
+        tensor_parallel={"tp_size": world_size},
         dtype=torch.bfloat16,
         replace_method="auto",
     )
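The one-line change here tracks DeepSpeed's API: newer releases take the tensor-parallel degree via a tensor_parallel config dict rather than the older mp_size keyword. Side by side, as a sketch with a stand-in module (the real example passes the loaded LLM, and tp_size > 1 requires running under a distributed launcher):

import torch
import torch.nn as nn
import deepspeed

world_size = 2           # tensor-parallel degree, as read from WORLD_SIZE
model = nn.Linear(8, 8)  # stand-in module so the call is concrete

# Older DeepSpeed keyword (what the example used before this commit):
#   engine = deepspeed.init_inference(model, mp_size=world_size, ...)
# Current form, as the commit now uses:
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": world_size},
    dtype=torch.bfloat16,
    replace_method="auto",
)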
@@ -96,11 +103,22 @@ def load_model(model_path, low_bit):
 init_distributed()
 
 # Load tokenizer
-tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
 
-def generate_text(prompt: str, n_predict: int = 32):
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
+def generate_text(prompt: List[str], n_predict = 32):
+    while prompt[-1] == "":
+        prompt = prompt[:-1]
+    if isinstance(n_predict, list):
+        n_predict = max(n_predict)
+
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
+    # print(input_ids)
+    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
     output = model.generate(input_ids,
+                            attention_mask=attention_mask,
                             max_new_tokens=n_predict,
                             use_cache=True)
     torch.xpu.synchronize()
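The switch to padding_side='left' matters for decoder-only models: generate() extends each sequence from its final position, so the pads must sit on the left to keep every row's real tokens flush with the generation point. A quick illustration with a stand-in checkpoint ('gpt2' is an assumption for the demo; the example tokenizes with the user-supplied model path):

from transformers import AutoTokenizer

# "gpt2" is a stand-in; the example uses --repo-id-or-model-path instead.
tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # decoder-only models often ship without a pad token

batch = tok(["a short prompt", "a somewhat longer prompt here"],
            return_tensors="pt", padding=True)
# With left padding, pad tokens precede the prompt, so the last position of
# every row is a real token -- exactly where generate() appends new tokens.
print(batch.input_ids)
print(batch.attention_mask)  # 0 over the pads, 1 over real tokens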
@@ -111,19 +129,59 @@ class PromptRequest(BaseModel):
     prompt: str
     n_predict: int = 32
 
+empty_req = PromptRequest(prompt="", n_predict=0)
+
 app = FastAPI()
 
+request_queue: asyncio.Queue = asyncio.Queue()
+result_dict: Dict[str, str] = {}
+
 @app.post("/generate/")
 async def generate(prompt_request: PromptRequest):
-    if local_rank == 0:
-        object_list = [prompt_request]
-        dist.broadcast_object_list(object_list, src=0)
-        start_time = time.time()
-        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
-        generate_time = time.time() - start_time
-        output = output.cpu()
-        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
-        return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}
-
+    request_id = str(uuid.uuid4())
+    await request_queue.put((request_id, prompt_request))
+    while True:
+        await asyncio.sleep(0.1)
+        if request_id in result_dict:
+            output_str = result_dict.pop(request_id)
+            return {"generated_text": output_str}
+
+
+async def process_requests():
+    while True:
+        request_ids, prompt_requests = [], []
+        for _ in range(max_num_seqs):
+            if request_queue.empty():
+                break
+            request_id, prompt_request = await request_queue.get()
+            request_ids.append(request_id)
+            prompt_requests.append(prompt_request)
+
+        if local_rank == 0 and prompt_requests:
+            # import pdb
+            # pdb.set_trace()
+            object_list = prompt_requests
+            if len(object_list) < max_num_seqs:
+                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
+            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
+            dist.broadcast_object_list(object_list, src=0)
+            start_time = time.time()
+            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
+            generate_time = time.time() - start_time
+            outputs = outputs.cpu()
+            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            output_strs = output_strs[:len(prompt_requests)]
+
+            for request_id, output_str in zip(request_ids, output_strs):
+                result_dict[request_id] = output_str
+
+        await asyncio.sleep(0.1)
+
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(process_requests())
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
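With the queue in place, the endpoint returns as soon as process_requests() posts its result, and concurrent clients are what actually fill a batch; note the response no longer includes the old generate_time field. A client sketch (assumes the server from the run script is listening on localhost:8000, per the --port 8000 flag):

import requests
from concurrent.futures import ThreadPoolExecutor

URL = "http://localhost:8000/generate/"  # port taken from the run script

def ask(prompt):
    r = requests.post(URL, json={"prompt": prompt, "n_predict": 32})
    return r.json()["generated_text"]

# Fire several requests at once so the scheduler batches them together.
prompts = ["What is AI?", "Define LLM.", "Tell me a joke.", "Name a prime."]
with ThreadPoolExecutor(max_workers=4) as pool:
    for text in pool.map(ask, prompts):
        print(text)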
@@ -143,7 +201,6 @@ if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=args.port)
 else:
     while True:
-        object_list = [None]
+        object_list = [None] * max_num_seqs
         dist.broadcast_object_list(object_list, src=0)
-        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
-
+        output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
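dist.broadcast_object_list requires every rank to supply a list of the same length, which is why rank 0 pads its batch with empty_req and the workers now preallocate [None] * max_num_seqs. A minimal standalone sketch of this rendezvous pattern (the gloo backend stands in for the example's XPU setup; launch with, e.g., torchrun --nproc_per_node=2):

import torch.distributed as dist

dist.init_process_group("gloo")
rank = dist.get_rank()
MAX_NUM_SEQS = 4  # every rank must broadcast a list of the same length

if rank == 0:
    # Pad real requests up to the fixed length, like empty_req in serving.py.
    batch = ["real request"] * 2 + [""] * (MAX_NUM_SEQS - 2)
else:
    # Placeholders; overwritten in place by the broadcast from rank 0.
    batch = [None] * MAX_NUM_SEQS

dist.broadcast_object_list(batch, src=0)
# All ranks now hold identical batches and can enter generate() in lockstep.
print(f"rank {rank}: {batch}")
dist.destroy_process_group()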