LLM: Refine Deepspped-AutoTP-FastAPI example (#10916)

This commit is contained in:
Xiangyu Tian 2024-05-07 09:37:31 +08:00 committed by GitHub
parent 1de878bee1
commit 13a44cdacb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 4842 additions and 13 deletions

File diff suppressed because it is too large Load diff

View file

@ -31,8 +31,7 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0
export WORLD_SIZE=2
export MAX_NUM_SEQS=16
mpirun -np $NUM_GPUS --prepend-rank \
python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192

View file

@ -31,6 +31,9 @@ from typing import Dict, List, Optional
from transformers.utils import logging
logger = logging.get_logger(__name__)
from benchmark_util import BenchmarkWrapper
def get_int_from_env(env_keys, default):
"""Returns the first positive env value found in the `env_keys` list or the default."""
for e in env_keys:
@ -39,9 +42,11 @@ def get_int_from_env(env_keys, default):
return val
return int(default)
global max_num_seqs
global max_num_batched_tokens
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
@ -92,6 +97,7 @@ def load_model(model_path, low_bit):
# Move model back to xpu
model = model.to(f'xpu:{local_rank}')
model = BenchmarkWrapper(model)
# Modify backend related settings
if world_size > 1:
@ -133,6 +139,8 @@ empty_req = PromptRequest(prompt="", n_predict=0)
app = FastAPI()
from collections import deque
rest_req_deque = deque(maxlen=128)
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
@ -150,16 +158,39 @@ async def generate(prompt_request: PromptRequest):
async def process_requests():
while True:
request_ids, prompt_requests = [], []
for _ in range(max_num_seqs):
if request_queue.empty():
break
request_id, prompt_request = await request_queue.get()
request_ids.append(request_id)
prompt_requests.append(prompt_request)
cur_batched_tokens = 0
if local_rank == 0:
while rest_req_deque:
request_id, rest_request = rest_req_deque.popleft()
prompt = rest_request.prompt
cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
cur_batched_tokens += cur_prompt_len
if cur_batched_tokens > max_num_batched_tokens:
cur_batched_tokens -= cur_prompt_len
rest_req_deque.appendleft((request_id, rest_request))
break
request_ids.append(request_id)
prompt_requests.append(rest_request)
if len(prompt_requests) == max_num_seqs:
break
for _ in range(max_num_seqs - len(prompt_requests)):
if request_queue.empty():
break
request_id, prompt_request = await request_queue.get()
# import pdb
# pdb.set_trace()
cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
cur_batched_tokens += cur_prompt_len
if cur_batched_tokens > max_num_batched_tokens:
cur_batched_tokens -= cur_prompt_len
rest_req_deque.appendleft((request_id, prompt_request))
break
request_ids.append(request_id)
prompt_requests.append(prompt_request)
if local_rank == 0 and prompt_requests:
# import pdb
# pdb.set_trace()
object_list = prompt_requests
if len(object_list) < max_num_seqs:
object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
@ -174,13 +205,15 @@ async def process_requests():
for request_id, output_str in zip(request_ids, output_strs):
result_dict[request_id] = output_str
logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
await asyncio.sleep(0.1)
@app.on_event("startup")
async def startup_event():
asyncio.create_task(process_requests())
if local_rank == 0:
asyncio.create_task(process_requests())
if __name__ == "__main__":
@ -192,10 +225,16 @@ if __name__ == "__main__":
help='The quantization type the model will convert to.')
parser.add_argument('--port', type=int, default=8000,
help='The port number on which the server will run.')
parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
help='Max tokens can be batched by this service.')
parser.add_argument('--max-num-seqs', type=int, default=8,
help='Max requests can be batched by this service.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
low_bit = args.low_bit
max_num_seqs = args.max_num_seqs
max_num_batched_tokens = args.max_num_batched_tokens
load_model(model_path, low_bit)
if local_rank == 0:
uvicorn.run(app, host="0.0.0.0", port=args.port)