LLM: Refine Deepspeed-AutoTP-FastAPI example (#10916)
This commit is contained in:
parent 1de878bee1
commit 13a44cdacb
3 changed files with 4842 additions and 13 deletions
4791  python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/benchmark_util.py  (normal file)
File diff suppressed because it is too large
@@ -31,8 +31,7 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0

export WORLD_SIZE=2
export MAX_NUM_SEQS=16

mpirun -np $NUM_GPUS --prepend-rank \
  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
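
The command above launches the FastAPI service on port 8000. A minimal client sketch for exercising it follows; the payload field names come from `PromptRequest(prompt=..., n_predict=...)` in serving.py, while the `/generate` route path is an assumption, since the endpoint decorator is not part of this diff.

# Hypothetical client for the service started above; the "/generate" path is assumed,
# only the payload field names (prompt, n_predict) are taken from serving.py.
import requests

payload = {"prompt": "What is AI?", "n_predict": 32}
resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
resp.raise_for_status()
print(resp.text)
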
@@ -31,6 +31,9 @@ from typing import Dict, List, Optional
from transformers.utils import logging
logger = logging.get_logger(__name__)

from benchmark_util import BenchmarkWrapper


def get_int_from_env(env_keys, default):
    """Returns the first positive env value found in the `env_keys` list or the default."""
    for e in env_keys:
@@ -39,9 +42,11 @@ def get_int_from_env(env_keys, default):
            return val
    return int(default)

global max_num_seqs
global max_num_batched_tokens

local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
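
The env-reading helper appears only in fragments across the two hunks above. A self-contained sketch of the pattern, with the elided lookup filled in as an assumption (read each variable, accept the first non-negative integer, otherwise fall back to the default):

import os

def get_int_from_env(env_keys, default):
    """Return the first non-negative integer found among env_keys, else the default."""
    for e in env_keys:
        val = int(os.environ.get(e, -1))   # assumed body of the elided middle of the hunk
        if val >= 0:
            return val
    return int(default)

# Mirrors the calls above: MAX_NUM_SEQS from the run script, falling back to 16.
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
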
@@ -92,6 +97,7 @@ def load_model(model_path, low_bit):
    # Move model back to xpu
    model = model.to(f'xpu:{local_rank}')
    model = BenchmarkWrapper(model)

    # Modify backend related settings
    if world_size > 1:
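
benchmark_util.py is added by this commit, but its diff is suppressed above. From serving.py's point of view the wrapper only needs to keep `generate()` working and expose `first_cost` / `rest_cost_mean`, which are read by the `logger.info` call later in this diff. The following is a stand-in sketch of that contract, not the actual BenchmarkWrapper implementation:

import time

class LatencyWrapperStandIn:
    """Interface stand-in only: exposes generate() plus first_cost / rest_cost_mean."""

    def __init__(self, model):
        self._model = model
        self.first_cost = None       # first-token latency (seconds) in the real wrapper
        self.rest_cost_mean = None   # mean next-token latency (seconds) in the real wrapper

    def __getattr__(self, name):
        # Delegate config, device, eval(), etc. to the wrapped model.
        return getattr(self._model, name)

    def generate(self, *args, **kwargs):
        start = time.perf_counter()
        out = self._model.generate(*args, **kwargs)
        total = time.perf_counter() - start
        # Placeholder accounting: spread the total time evenly over the requested tokens.
        # The real benchmark_util wrapper times each forward pass individually.
        n = max(int(kwargs.get("max_new_tokens", 1)), 1)
        self.first_cost = total / n
        self.rest_cost_mean = total / n
        return out
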
@@ -133,6 +139,8 @@ empty_req = PromptRequest(prompt="", n_predict=0)
app = FastAPI()

from collections import deque
rest_req_deque = deque(maxlen=128)
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
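
The body of the `generate` endpoint is outside this diff, so the following is an assumption about how these structures fit together: each incoming request is given an id, pushed onto `request_queue`, and its handler waits until the batched worker publishes the output in `result_dict` under that id.

# Assumed producer-side pattern for the request_queue / result_dict pair declared above.
import asyncio
import uuid

async def submit_and_wait(prompt_request, poll_interval=0.1):
    request_id = str(uuid.uuid4())
    await request_queue.put((request_id, prompt_request))   # consumed by process_requests()
    while request_id not in result_dict:                    # filled in by the worker loop
        await asyncio.sleep(poll_interval)
    return result_dict.pop(request_id)
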
@@ -150,16 +158,39 @@ async def generate(prompt_request: PromptRequest):
async def process_requests():
    while True:
        request_ids, prompt_requests = [], []
        for _ in range(max_num_seqs):
            if request_queue.empty():
                break
            request_id, prompt_request = await request_queue.get()
            request_ids.append(request_id)
            prompt_requests.append(prompt_request)
        cur_batched_tokens = 0

        if local_rank == 0:
            while rest_req_deque:
                request_id, rest_request = rest_req_deque.popleft()
                prompt = rest_request.prompt
                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
                cur_batched_tokens += cur_prompt_len
                if cur_batched_tokens > max_num_batched_tokens:
                    cur_batched_tokens -= cur_prompt_len
                    rest_req_deque.appendleft((request_id, rest_request))
                    break
                request_ids.append(request_id)
                prompt_requests.append(rest_request)
                if len(prompt_requests) == max_num_seqs:
                    break

            for _ in range(max_num_seqs - len(prompt_requests)):
                if request_queue.empty():
                    break
                request_id, prompt_request = await request_queue.get()
                # import pdb
                # pdb.set_trace()
                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
                cur_batched_tokens += cur_prompt_len
                if cur_batched_tokens > max_num_batched_tokens:
                    cur_batched_tokens -= cur_prompt_len
                    rest_req_deque.appendleft((request_id, prompt_request))
                    break
                request_ids.append(request_id)
                prompt_requests.append(prompt_request)

        if local_rank == 0 and prompt_requests:
            # import pdb
            # pdb.set_trace()
            object_list = prompt_requests
            if len(object_list) < max_num_seqs:
                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
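
The admission logic above interleaves deferred requests from `rest_req_deque` with fresh ones from `request_queue`; the rule itself is: keep adding prompts until either `max_num_seqs` requests or `max_num_batched_tokens` prompt tokens are reached, and push the request that would overflow the budget back to the front for the next round. Below is a standalone sketch of that rule with a toy token counter; the names here are illustrative, not taken from serving.py.

from collections import deque

def select_batch(pending, count_tokens, max_num_seqs=8, max_num_batched_tokens=8192):
    """Illustrative restatement of the batching rule above, outside the asyncio loop."""
    batch, cur_batched_tokens = [], 0
    while pending and len(batch) < max_num_seqs:
        request_id, req = pending.popleft()
        n_tokens = count_tokens(req)
        if cur_batched_tokens + n_tokens > max_num_batched_tokens:
            pending.appendleft((request_id, req))   # same idea as rest_req_deque.appendleft(...)
            break
        cur_batched_tokens += n_tokens
        batch.append((request_id, req))
    return batch

# Toy example: whitespace split stands in for the real tokenizer's length count.
pending = deque([("a", "one two three"), ("b", "four five"), ("c", "six")])
print(select_batch(pending, lambda p: len(p.split()), max_num_seqs=2))
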
@@ -174,13 +205,15 @@ async def process_requests():
            for request_id, output_str in zip(request_ids, output_strs):
                result_dict[request_id] = output_str
            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")

        await asyncio.sleep(0.1)


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests())
    if local_rank == 0:
        asyncio.create_task(process_requests())


if __name__ == "__main__":
@@ -192,10 +225,16 @@ if __name__ == "__main__":
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')

    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
                        help='Max tokens can be batched by this service.')
    parser.add_argument('--max-num-seqs', type=int, default=8,
                        help='Max requests can be batched by this service.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    max_num_seqs = args.max_num_seqs
    max_num_batched_tokens = args.max_num_batched_tokens
    load_model(model_path, low_bit)
    if local_rank == 0:
        uvicorn.run(app, host="0.0.0.0", port=args.port)