LLM: Refine Deepspeed-AutoTP-FastAPI example (#10916)
parent 1de878bee1
commit 13a44cdacb
3 changed files with 4842 additions and 13 deletions
python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/benchmark_util.py (new file, 4791 additions)
File diff suppressed because it is too large.
@@ -31,8 +31,7 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 export TORCH_LLM_ALLREDUCE=0
 
 export WORLD_SIZE=2
-export MAX_NUM_SEQS=16
 
 mpirun -np $NUM_GPUS --prepend-rank \
-  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
+  python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
 
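For reference, a client call against the service started by the mpirun command above could look like the minimal sketch below. Only the port and the PromptRequest fields (prompt, n_predict) come from this diff; the "/generate/" route path and the use of the requests library are assumptions for illustration, so check the route actually declared on the generate handler in serving.py.

import requests  # pip install requests

# Hypothetical client for the serving.py FastAPI app launched above.
# The "/generate/" route is an assumption, not taken from this diff.
response = requests.post(
    "http://localhost:8000/generate/",
    json={"prompt": "What is AI?", "n_predict": 32},  # PromptRequest fields shown in the diff
)
print(response.status_code, response.json())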
python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py
@@ -31,6 +31,9 @@ from typing import Dict, List, Optional
 from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
+from benchmark_util import BenchmarkWrapper
+
+
 def get_int_from_env(env_keys, default):
     """Returns the first positive env value found in the `env_keys` list or the default."""
     for e in env_keys:
@@ -39,9 +42,11 @@ def get_int_from_env(env_keys, default):
             return val
     return int(default)
 
+global max_num_seqs
+global max_num_batched_tokens
+
 local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
-max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
 os.environ["RANK"] = str(local_rank)
 os.environ["WORLD_SIZE"] = str(world_size)
 os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
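Only fragments of get_int_from_env are visible in the hunks above. A minimal sketch of the presumed full helper, assuming it follows the common pattern of returning the first non-negative integer found among the listed environment variables, is:

import os

def get_int_from_env(env_keys, default):
    """Returns the first positive env value found in the `env_keys` list or the default."""
    for e in env_keys:
        # -1 marks "not set"; any non-negative value wins.
        val = int(os.environ.get(e, -1))
        if val >= 0:
            return val
    return int(default)

With a helper like this, LOCAL_RANK/PMI_RANK and WORLD_SIZE/PMI_SIZE cover both torchrun-style and MPI-style launches, which is why the mpirun command earlier needs no extra rank plumbing.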
@@ -92,6 +97,7 @@ def load_model(model_path, low_bit):
 
     # Move model back to xpu
     model = model.to(f'xpu:{local_rank}')
+    model = BenchmarkWrapper(model)
 
     # Modify backend related settings
     if world_size > 1:
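The 4791-line benchmark_util.py is suppressed above, so the actual BenchmarkWrapper implementation is not visible here. Conceptually it is a pass-through wrapper that times generation and exposes the first_cost and rest_cost_mean attributes logged further down. The simplified sketch below is only an illustration of that idea, not the real code; the class name and the crude timing are assumptions.

import time

class PassThroughBenchmarkWrapper:
    """Illustrative stand-in for benchmark_util.BenchmarkWrapper (real file suppressed above)."""

    def __init__(self, model):
        self.model = model
        self.first_cost = None      # latency of the first generated token (seconds)
        self.rest_cost_mean = None  # mean latency of the remaining tokens (seconds)

    def __getattr__(self, name):
        # Delegate config, .to(), .eval(), etc. to the wrapped model.
        return getattr(self.model, name)

    def generate(self, *args, **kwargs):
        # The real wrapper hooks into the decoding loop to time every token;
        # this sketch only records total wall-clock time as a rough proxy for first_cost.
        start = time.perf_counter()
        output = self.model.generate(*args, **kwargs)
        self.first_cost = time.perf_counter() - start
        # rest_cost_mean is left unset in this sketch.
        return output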
@@ -133,6 +139,8 @@ empty_req = PromptRequest(prompt="", n_predict=0)
 
 app = FastAPI()
 
+from collections import deque
+rest_req_deque = deque(maxlen=128)
 request_queue: asyncio.Queue = asyncio.Queue()
 result_dict: Dict[str, str] = {}
 
@@ -150,16 +158,39 @@ async def generate(prompt_request: PromptRequest):
 async def process_requests():
     while True:
         request_ids, prompt_requests = [], []
-        for _ in range(max_num_seqs):
+        cur_batched_tokens = 0
+
+        if local_rank == 0:
+            while rest_req_deque:
+                request_id, rest_request = rest_req_deque.popleft()
+                prompt = rest_request.prompt
+                cur_prompt_len = tokenizer(prompt, return_tensors="pt").input_ids.size(1)
+                cur_batched_tokens += cur_prompt_len
+                if cur_batched_tokens > max_num_batched_tokens:
+                    cur_batched_tokens -= cur_prompt_len
+                    rest_req_deque.appendleft((request_id, rest_request))
+                    break
+                request_ids.append(request_id)
+                prompt_requests.append(rest_request)
+                if len(prompt_requests) == max_num_seqs:
+                    break
+
+        for _ in range(max_num_seqs - len(prompt_requests)):
             if request_queue.empty():
                 break
             request_id, prompt_request = await request_queue.get()
+            # import pdb
+            # pdb.set_trace()
+            cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
+            cur_batched_tokens += cur_prompt_len
+            if cur_batched_tokens > max_num_batched_tokens:
+                cur_batched_tokens -= cur_prompt_len
+                rest_req_deque.appendleft((request_id, prompt_request))
+                break
             request_ids.append(request_id)
             prompt_requests.append(prompt_request)
 
         if local_rank == 0 and prompt_requests:
-            # import pdb
-            # pdb.set_trace()
             object_list = prompt_requests
             if len(object_list) < max_num_seqs:
                 object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
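The scheduling logic added in this hunk can be hard to follow inside the diff. The standalone sketch below restates the same admission rule with a hypothetical fill_batch helper applied to the carried-over deque (the second loop in serving.py applies the same rule to requests pulled from the asyncio queue): keep adding requests until either max_num_seqs requests or max_num_batched_tokens prompt tokens are reached, and push anything that does not fit back to the front for the next cycle. The helper name and the prompt_len callback are illustrative assumptions.

from collections import deque
from typing import Callable, Deque, List, Tuple

def fill_batch(
    rest: Deque[Tuple[str, str]],        # (request_id, prompt) pairs carried over from earlier cycles
    prompt_len: Callable[[str], int],    # e.g. lambda p: tokenizer(p, return_tensors="pt").input_ids.size(1)
    max_num_seqs: int,
    max_num_batched_tokens: int,
) -> List[Tuple[str, str]]:
    """Hypothetical restatement of the batching policy in process_requests()."""
    batch: List[Tuple[str, str]] = []
    batched_tokens = 0
    while rest and len(batch) < max_num_seqs:
        request_id, prompt = rest.popleft()
        n = prompt_len(prompt)
        if batched_tokens + n > max_num_batched_tokens:
            rest.appendleft((request_id, prompt))  # does not fit; retry next cycle
            break
        batched_tokens += n
        batch.append((request_id, prompt))
    return batch

# Example: with a 3-token budget only the first two requests are admitted.
pending = deque([("r1", "a b"), ("r2", "c"), ("r3", "d e f")])
print(fill_batch(pending, lambda p: len(p.split()), max_num_seqs=8, max_num_batched_tokens=3))
# -> [('r1', 'a b'), ('r2', 'c')]; 'r3' stays in `pending` for the next cycle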
@@ -174,12 +205,14 @@ async def process_requests():
 
             for request_id, output_str in zip(request_ids, output_strs):
                 result_dict[request_id] = output_str
+            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
 
         await asyncio.sleep(0.1)
 
 
 @app.on_event("startup")
 async def startup_event():
-    asyncio.create_task(process_requests())
+    if local_rank == 0:
+        asyncio.create_task(process_requests())
@@ -192,10 +225,16 @@ if __name__ == "__main__":
                         help='The quantization type the model will convert to.')
     parser.add_argument('--port', type=int, default=8000,
                         help='The port number on which the server will run.')
+    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
+                        help='Maximum number of tokens that can be batched by this service.')
+    parser.add_argument('--max-num-seqs', type=int, default=8,
+                        help='Maximum number of requests that can be batched by this service.')
 
     args = parser.parse_args()
     model_path = args.repo_id_or_model_path
     low_bit = args.low_bit
+    max_num_seqs = args.max_num_seqs
+    max_num_batched_tokens = args.max_num_batched_tokens
     load_model(model_path, low_bit)
     if local_rank == 0:
         uvicorn.run(app, host="0.0.0.0", port=args.port)