LLM: Refine Deepspped-AutoTP-FastAPI example (#10916)
This commit is contained in:
		
							parent
							
								
									1de878bee1
								
							
						
					
					
						commit
						13a44cdacb
					
				
					 3 changed files with 4842 additions and 13 deletions
				
			
		
							
								
								
									
										4791
									
								
								python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/benchmark_util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4791
									
								
								python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/benchmark_util.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							| 
						 | 
				
			
			@ -31,8 +31,7 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 | 
			
		|||
export TORCH_LLM_ALLREDUCE=0
 | 
			
		||||
 | 
			
		||||
export WORLD_SIZE=2
 | 
			
		||||
export MAX_NUM_SEQS=16
 | 
			
		||||
 | 
			
		||||
mpirun -np $NUM_GPUS --prepend-rank \
 | 
			
		||||
        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
 | 
			
		||||
        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,6 +31,9 @@ from typing import Dict, List, Optional
 | 
			
		|||
from transformers.utils import logging
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
from benchmark_util import BenchmarkWrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_int_from_env(env_keys, default):
 | 
			
		||||
    """Returns the first positive env value found in the `env_keys` list or the default."""
 | 
			
		||||
    for e in env_keys:
 | 
			
		||||
| 
						 | 
				
			
			@ -39,9 +42,11 @@ def get_int_from_env(env_keys, default):
 | 
			
		|||
            return val
 | 
			
		||||
    return int(default)
 | 
			
		||||
 | 
			
		||||
global max_num_seqs
 | 
			
		||||
global max_num_batched_tokens
 | 
			
		||||
 | 
			
		||||
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 | 
			
		||||
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
 | 
			
		||||
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
 | 
			
		||||
os.environ["RANK"] = str(local_rank)
 | 
			
		||||
os.environ["WORLD_SIZE"] = str(world_size)
 | 
			
		||||
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 | 
			
		||||
| 
						 | 
				
			
			@ -92,6 +97,7 @@ def load_model(model_path, low_bit):
 | 
			
		|||
 | 
			
		||||
    # Move model back to xpu
 | 
			
		||||
    model = model.to(f'xpu:{local_rank}')
 | 
			
		||||
    model = BenchmarkWrapper(model)
 | 
			
		||||
 | 
			
		||||
    # Modify backend related settings 
 | 
			
		||||
    if world_size > 1:
 | 
			
		||||
| 
						 | 
				
			
			@ -133,6 +139,8 @@ empty_req = PromptRequest(prompt="", n_predict=0)
 | 
			
		|||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
from collections import deque
 | 
			
		||||
rest_req_deque = deque(maxlen=128)
 | 
			
		||||
request_queue: asyncio.Queue = asyncio.Queue()
 | 
			
		||||
result_dict: Dict[str, str] = {}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -150,16 +158,39 @@ async def generate(prompt_request: PromptRequest):
 | 
			
		|||
async def process_requests():
 | 
			
		||||
    while True:
 | 
			
		||||
        request_ids, prompt_requests = [], []
 | 
			
		||||
        for _ in range(max_num_seqs):
 | 
			
		||||
            if request_queue.empty():
 | 
			
		||||
                break
 | 
			
		||||
            request_id, prompt_request = await request_queue.get()
 | 
			
		||||
            request_ids.append(request_id)
 | 
			
		||||
            prompt_requests.append(prompt_request)
 | 
			
		||||
        cur_batched_tokens = 0
 | 
			
		||||
 | 
			
		||||
        if local_rank == 0:
 | 
			
		||||
            while rest_req_deque:
 | 
			
		||||
                request_id, rest_request = rest_req_deque.popleft()
 | 
			
		||||
                prompt = rest_request.prompt
 | 
			
		||||
                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
 | 
			
		||||
                cur_batched_tokens += cur_prompt_len
 | 
			
		||||
                if cur_batched_tokens > max_num_batched_tokens:
 | 
			
		||||
                    cur_batched_tokens -= cur_prompt_len
 | 
			
		||||
                    rest_req_deque.appendleft((request_id, rest_request))
 | 
			
		||||
                    break
 | 
			
		||||
                request_ids.append(request_id)
 | 
			
		||||
                prompt_requests.append(rest_request)
 | 
			
		||||
                if len(prompt_requests) == max_num_seqs:
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
            for _ in range(max_num_seqs - len(prompt_requests)):
 | 
			
		||||
                if request_queue.empty():
 | 
			
		||||
                    break
 | 
			
		||||
                request_id, prompt_request = await request_queue.get()
 | 
			
		||||
                # import pdb
 | 
			
		||||
                # pdb.set_trace()
 | 
			
		||||
                cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1)
 | 
			
		||||
                cur_batched_tokens += cur_prompt_len
 | 
			
		||||
                if cur_batched_tokens > max_num_batched_tokens:
 | 
			
		||||
                    cur_batched_tokens -= cur_prompt_len
 | 
			
		||||
                    rest_req_deque.appendleft((request_id, prompt_request))
 | 
			
		||||
                    break
 | 
			
		||||
                request_ids.append(request_id)
 | 
			
		||||
                prompt_requests.append(prompt_request)
 | 
			
		||||
 | 
			
		||||
        if local_rank == 0 and prompt_requests:
 | 
			
		||||
            # import pdb
 | 
			
		||||
            # pdb.set_trace()
 | 
			
		||||
            object_list = prompt_requests
 | 
			
		||||
            if len(object_list) < max_num_seqs:
 | 
			
		||||
                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
 | 
			
		||||
| 
						 | 
				
			
			@ -174,13 +205,15 @@ async def process_requests():
 | 
			
		|||
 | 
			
		||||
            for request_id, output_str in zip(request_ids, output_strs):
 | 
			
		||||
                result_dict[request_id] = output_str
 | 
			
		||||
            logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}")
 | 
			
		||||
 | 
			
		||||
        await asyncio.sleep(0.1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.on_event("startup")
 | 
			
		||||
async def startup_event():
 | 
			
		||||
    asyncio.create_task(process_requests())
 | 
			
		||||
    if local_rank == 0:
 | 
			
		||||
        asyncio.create_task(process_requests())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
| 
						 | 
				
			
			@ -192,10 +225,16 @@ if __name__ == "__main__":
 | 
			
		|||
                    help='The quantization type the model will convert to.')
 | 
			
		||||
    parser.add_argument('--port', type=int, default=8000,
 | 
			
		||||
                    help='The port number on which the server will run.')
 | 
			
		||||
    parser.add_argument('--max-num-batched-tokens', type=int, default=4096,
 | 
			
		||||
                    help='Max tokens can be batched by this service.')
 | 
			
		||||
    parser.add_argument('--max-num-seqs', type=int, default=8,
 | 
			
		||||
                    help='Max requests can be batched by this service.')
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    model_path = args.repo_id_or_model_path
 | 
			
		||||
    low_bit = args.low_bit
 | 
			
		||||
    max_num_seqs = args.max_num_seqs
 | 
			
		||||
    max_num_batched_tokens = args.max_num_batched_tokens
 | 
			
		||||
    load_model(model_path, low_bit)
 | 
			
		||||
    if local_rank == 0:
 | 
			
		||||
        uvicorn.run(app, host="0.0.0.0", port=args.port)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue