LLM: Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example (#10876)
Enable batch generate (world_size>1) in Deepspeed-AutoTP-FastAPI example.
This commit is contained in:
		
							parent
							
								
									3e8ed54270
								
							
						
					
					
						commit
						3d4950b0f0
					
				
					 2 changed files with 76 additions and 16 deletions
				
			
		| 
						 | 
				
			
			@ -30,6 +30,9 @@ export USE_XETLA=OFF
 | 
			
		|||
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 | 
			
		||||
export TORCH_LLM_ALLREDUCE=0
 | 
			
		||||
 | 
			
		||||
export WORLD_SIZE=2
 | 
			
		||||
export MAX_NUM_SEQS=16
 | 
			
		||||
 | 
			
		||||
mpirun -np $NUM_GPUS --prepend-rank \
 | 
			
		||||
        python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,6 +25,12 @@ from fastapi import FastAPI, HTTPException
 | 
			
		|||
from pydantic import BaseModel
 | 
			
		||||
import uvicorn
 | 
			
		||||
 | 
			
		||||
import asyncio, uuid
 | 
			
		||||
from typing import Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from transformers.utils import logging
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
def get_int_from_env(env_keys, default):
 | 
			
		||||
    """Returns the first positive env value found in the `env_keys` list or the default."""
 | 
			
		||||
    for e in env_keys:
 | 
			
		||||
| 
						 | 
				
			
			@ -35,6 +41,7 @@ def get_int_from_env(env_keys, default):
 | 
			
		|||
 | 
			
		||||
local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0")
 | 
			
		||||
world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1")
 | 
			
		||||
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
 | 
			
		||||
os.environ["RANK"] = str(local_rank)
 | 
			
		||||
os.environ["WORLD_SIZE"] = str(world_size)
 | 
			
		||||
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
 | 
			
		||||
| 
						 | 
				
			
			@ -70,7 +77,7 @@ def load_model(model_path, low_bit):
 | 
			
		|||
 | 
			
		||||
    model = deepspeed.init_inference(
 | 
			
		||||
        model,
 | 
			
		||||
        mp_size=world_size,
 | 
			
		||||
        tensor_parallel={"tp_size": world_size},
 | 
			
		||||
        dtype=torch.bfloat16,
 | 
			
		||||
        replace_method="auto",
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -96,11 +103,22 @@ def load_model(model_path, low_bit):
 | 
			
		|||
    init_distributed()
 | 
			
		||||
 | 
			
		||||
    # Load tokenizer
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
 | 
			
		||||
    if tokenizer.pad_token is None:
 | 
			
		||||
        tokenizer.pad_token = tokenizer.eos_token
 | 
			
		||||
 | 
			
		||||
def generate_text(prompt: str, n_predict: int = 32):
 | 
			
		||||
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(f'xpu:{local_rank}')
 | 
			
		||||
def generate_text(prompt: List[str], n_predict = 32):
 | 
			
		||||
    while prompt[-1] == "":
 | 
			
		||||
        prompt = prompt[:-1]
 | 
			
		||||
    if isinstance(n_predict, list):
 | 
			
		||||
        n_predict = max(n_predict)
 | 
			
		||||
        
 | 
			
		||||
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
 | 
			
		||||
    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
 | 
			
		||||
    # print(input_ids)
 | 
			
		||||
    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
 | 
			
		||||
    output = model.generate(input_ids,
 | 
			
		||||
                            attention_mask=attention_mask,
 | 
			
		||||
                            max_new_tokens=n_predict,
 | 
			
		||||
                            use_cache=True)
 | 
			
		||||
    torch.xpu.synchronize()
 | 
			
		||||
| 
						 | 
				
			
			@ -111,19 +129,59 @@ class PromptRequest(BaseModel):
 | 
			
		|||
    prompt: str
 | 
			
		||||
    n_predict: int = 32  
 | 
			
		||||
 | 
			
		||||
empty_req = PromptRequest(prompt="", n_predict=0)
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
request_queue: asyncio.Queue = asyncio.Queue()
 | 
			
		||||
result_dict: Dict[str, str] = {}
 | 
			
		||||
 | 
			
		||||
@app.post("/generate/")
 | 
			
		||||
async def generate(prompt_request: PromptRequest):
 | 
			
		||||
    if local_rank == 0:
 | 
			
		||||
        object_list = [prompt_request]
 | 
			
		||||
    request_id = str(uuid.uuid4())
 | 
			
		||||
    await request_queue.put((request_id, prompt_request))
 | 
			
		||||
    while True:
 | 
			
		||||
        await asyncio.sleep(0.1)
 | 
			
		||||
        if request_id in result_dict:
 | 
			
		||||
            output_str = result_dict.pop(request_id)
 | 
			
		||||
            return {"generated_text": output_str}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def process_requests():
 | 
			
		||||
    while True:
 | 
			
		||||
        request_ids, prompt_requests = [], []
 | 
			
		||||
        for _ in range(max_num_seqs):
 | 
			
		||||
            if request_queue.empty():
 | 
			
		||||
                break
 | 
			
		||||
            request_id, prompt_request = await request_queue.get()
 | 
			
		||||
            request_ids.append(request_id)
 | 
			
		||||
            prompt_requests.append(prompt_request)
 | 
			
		||||
 | 
			
		||||
        if local_rank == 0 and prompt_requests:
 | 
			
		||||
            # import pdb
 | 
			
		||||
            # pdb.set_trace()
 | 
			
		||||
            object_list = prompt_requests
 | 
			
		||||
            if len(object_list) < max_num_seqs:
 | 
			
		||||
                object_list = object_list + [empty_req] * (max_num_seqs - len(object_list))
 | 
			
		||||
            logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}")
 | 
			
		||||
            dist.broadcast_object_list(object_list, src=0)
 | 
			
		||||
            start_time = time.time()
 | 
			
		||||
        output = generate_text(object_list[0].prompt, object_list[0].n_predict)
 | 
			
		||||
            outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
 | 
			
		||||
            generate_time = time.time() - start_time
 | 
			
		||||
        output = output.cpu()
 | 
			
		||||
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
 | 
			
		||||
        return {"generated_text": output_str, "generate_time": f'{generate_time:.3f}s'}
 | 
			
		||||
            outputs = outputs.cpu()
 | 
			
		||||
            output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
 | 
			
		||||
            output_strs = output_strs[:len(prompt_requests)]
 | 
			
		||||
 | 
			
		||||
            for request_id, output_str in zip(request_ids, output_strs):
 | 
			
		||||
                result_dict[request_id] = output_str
 | 
			
		||||
 | 
			
		||||
        await asyncio.sleep(0.1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.on_event("startup")
 | 
			
		||||
async def startup_event():
 | 
			
		||||
    asyncio.create_task(process_requests())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
 | 
			
		||||
| 
						 | 
				
			
			@ -143,7 +201,6 @@ if __name__ == "__main__":
 | 
			
		|||
        uvicorn.run(app, host="0.0.0.0", port=args.port)
 | 
			
		||||
    else:
 | 
			
		||||
        while True:
 | 
			
		||||
            object_list = [None]
 | 
			
		||||
            object_list = [None] * max_num_seqs
 | 
			
		||||
            dist.broadcast_object_list(object_list, src=0)
 | 
			
		||||
            output = generate_text(object_list[0].prompt, object_list[0].n_predict)
 | 
			
		||||
 | 
			
		||||
            output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue