From 23681fbf5c215c12440736f977694ca6e62efdc8 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Fri, 26 Jul 2024 09:41:03 +0800 Subject: [PATCH] Support codegeex4-9b for lightweight-serving (#11648) * add options, support prompt and not return end_token * enable openai parameter * set do_sample None and update style --- .../example/GPU/Lightweight-Serving/README.md | 7 +- .../ipex_llm/serving/fastapi/api_server.py | 226 +++++++++++------- .../ipex_llm/serving/fastapi/model_worker.py | 5 + 3 files changed, 149 insertions(+), 89 deletions(-) diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md index 791bf166..104032c6 100644 --- a/python/llm/example/GPU/Lightweight-Serving/README.md +++ b/python/llm/example/GPU/Lightweight-Serving/README.md @@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Llama-2-7b-chat-hf", - "messages": [{"role": "user", "content": "Hello! What is your name?"}] + "messages": [{"role": "user", "content": "Hello! What is your name?"}], + "stream": false }' ``` #### /v1/completions ```bash + curl http://localhost:8000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "model": "Llama-2-7b-chat-hf", "prompt": "Once upon a time", - "max_tokens": 32 + "max_tokens": 32, + "stream": false }' ``` diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py index 75bd49d0..5109c822 100644 --- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py +++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py @@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError import asyncio import uuid from typing import List, Optional, Union, Dict +from fastapi.middleware.cors import CORSMiddleware from .tgi_protocol import Parameters - - -result_dict: Dict[str, str] = {} -logger = logging.get_logger(__name__) - - -class InputsRequest(BaseModel): - inputs: str - parameters: Optional[Parameters] = None - stream: Optional[bool] = False - req_type: str = 'completion' - - -class ChatCompletionRequest(BaseModel): - messages: List[ChatCompletionMessageParam] - model: str - max_tokens: Optional[int] = None - stream: Optional[bool] = False - - -class CompletionRequest(BaseModel): - model: str - prompt: Union[List[int], List[List[int]], str, List[str]] - max_tokens: Optional[int] = None - stream: Optional[bool] = False - - -app = FastAPI() -global tokenizer -global local_model - - -class FastApp(): - def __init__(self, model, mytokenizer): - global tokenizer - global local_model - local_model = model - tokenizer = mytokenizer - self.app = app - - from .openai_protocol import ( ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, @@ -80,6 +40,63 @@ from .openai_protocol import ( CompletionStreamResponse, ) +result_dict: Dict[str, str] = {} +logger = logging.get_logger(__name__) + + +class InputsRequest(BaseModel): + inputs: str + parameters: Optional[Parameters] = None + stream: Optional[bool] = False + req_type: str = 'completion' + + +class ChatCompletionRequest(BaseModel): + messages: List[ChatMessage] + model: str + max_tokens: Optional[int] = None + min_tokens: Optional[int] = None + stream: Optional[bool] = False + top_p: Optional[float] = None + top_k: Optional[int] = None + repetition_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + temperature: Optional[float] = None + + +class 
CompletionRequest(BaseModel): + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + max_tokens: Optional[int] = None + min_tokens: Optional[int] = None + stream: Optional[bool] = False + top_p: Optional[float] = None + top_k: Optional[int] = None + repetition_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + temperature: Optional[float] = None + + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], +) + +global tokenizer +global local_model + + +class FastApp(): + def __init__(self, model, mytokenizer): + global tokenizer + global local_model + local_model = model + tokenizer = mytokenizer + self.app = app + def get_queue_next_token(delta_text_queue): timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60)) @@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue): remain = 1 return delta_text, remain + +def should_return_end_token(next_token): + if "codegeex" not in local_model.model_name.lower(): + return True + else: + if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]: + return False + return True + async def chat_stream_generator(local_model, delta_text_queue, request_id): model_name = local_model.model_name index = 0 @@ -104,18 +130,19 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id): await asyncio.sleep(0) continue if remain == 0 and delta_text is not None or remain != 0: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(role="assistant", content=delta_text), - logprobs=None, - finish_reason=None) - chunk = ChatCompletionStreamResponse( - id=request_id, - choices=[choice_data], - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - index = index + 1 + if should_return_end_token(delta_text): + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role="assistant", content=delta_text), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + index = index + 1 if remain == 0: choice_data = ChatCompletionResponseStreamChoice( index=index, @@ -146,18 +173,19 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id) await asyncio.sleep(0) continue if remain == 0 and delta_text is not None or remain != 0: - choice_data = CompletionResponseStreamChoice( - index=index, - text=delta_text, - logprobs=None, - finish_reason=None) - chunk = CompletionStreamResponse( - id=request_id, - choices=[choice_data], - model=model_name) - data = chunk.model_dump_json(exclude_unset=True) - yield f"data: {data}\n\n" - index = index + 1 + if should_return_end_token(delta_text): + choice_data = CompletionResponseStreamChoice( + index=index, + text=delta_text, + logprobs=None, + finish_reason=None) + chunk = CompletionStreamResponse( + id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + index = index + 1 if remain == 0: choice_data = CompletionResponseStreamChoice( index=index, @@ -237,31 +265,59 @@ async def generate_stream(inputs_request: InputsRequest): def get_prompt(messages) -> str: - prompt = "" - for msg in messages: - role = msg["role"] - content = msg["content"] - if role == "system": - prompt += f"<>\n{content}\n<>\n\n" - elif role == "user": - 
prompt += f"[INST] {content} [/INST] " - elif role == "assistant": - prompt += f"{content} " + if "codegeex" in local_model.model_name.lower(): + query = messages[-1].content + if len(messages) <= 1: + history = [] else: - invalidInputError(False, f"Unknown role: {role}") - return prompt.strip() + history = [msg.model_dump() for msg in messages[:-1]] + history.append({"role": "user", "content": query}) + inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False, + return_tensors="pt", return_dict=False) + return inputs + else: + prompt = "" + for msg in messages: + role = msg.role + content = msg.content + if role == "system": + prompt += f"<>\n{content}\n<>\n\n" + elif role == "user": + prompt += f"[INST] {content} [/INST] " + elif role == "assistant": + prompt += f"{content} " + else: + invalidInputError(False, f"Unknown role: {role}") + return prompt.strip() + + +def set_parameters(req): + if req.max_tokens is None: + n_predict = 256 + else: + n_predict = req.max_tokens + if req.repetition_penalty is not None: + repetition_penalty = req.repetition_penalty + elif req.presence_penalty is not None: + repetition_penalty = req.presence_penalty + else: + repetition_penalty = None + if req.temperature is not None and req.temperature > 1e-4: + do_sample = True + else: + do_sample = False + return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens, + top_p=req.top_p, repetition_penalty=repetition_penalty, + temperature=req.temperature, top_k=req.top_k) @app.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): + print(request) model_name = local_model.model_name - if request.max_tokens is None: - n_predict = 256 - else: - n_predict = request.max_tokens inputs_request = InputsRequest( inputs=get_prompt(request.messages), - parameters=Parameters(max_new_tokens=n_predict), + parameters=set_parameters(request), stream=request.stream, req_type="chat" ) @@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest): @app.post("/v1/completions") async def create_completion(request: CompletionRequest): model_name = local_model.model_name - if request.max_tokens is None: - n_predict = 32 - else: - n_predict = request.max_tokens inputs_request = InputsRequest( inputs=request.prompt, - parameters=Parameters(max_new_tokens=n_predict), + parameters=set_parameters(request), stream=request.stream, req_type="completion" ) diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py index c88e1434..099cfc0c 100644 --- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py +++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py @@ -73,6 +73,11 @@ class ModelWorker: def model_generate(): generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None} + if "codegeex" in self.model_name.lower(): + eos_token_id = [tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|user|>"), + tokenizer.convert_tokens_to_ids("<|observation|>")] + generate_kwargs["eos_token_id"] = eos_token_id self.model.generate(input_ids, streamer=self.streamer[request_id], **generate_kwargs) torch.xpu.empty_cache()
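
Below is a minimal client sketch (not part of the patch itself) that exercises the new OpenAI-style sampling fields added to `ChatCompletionRequest` in this change. It assumes the lightweight-serving server is running at `http://localhost:8000` with `Llama-2-7b-chat-hf` loaded (as in the README example above), and that the non-streaming reply follows the usual OpenAI chat-completion shape; the streaming branch simply parses the `data: {json}` SSE lines emitted by `chat_stream_generator`.

```python
# Hypothetical client for the patched /v1/chat/completions endpoint.
# Assumptions: server at localhost:8000, model "Llama-2-7b-chat-hf" loaded,
# non-streaming responses in the standard OpenAI chat-completion format.
import json

import requests

BASE_URL = "http://localhost:8000"


def chat(prompt: str, stream: bool = False) -> None:
    payload = {
        "model": "Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 64,
        # New fields accepted after this patch. Per set_parameters(), a
        # temperature above 1e-4 enables sampling; otherwise decoding is greedy.
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "repetition_penalty": 1.1,
        "stream": stream,
    }
    resp = requests.post(f"{BASE_URL}/v1/chat/completions", json=payload, stream=stream)
    resp.raise_for_status()

    if not stream:
        print(resp.json()["choices"][0]["message"]["content"])
        return

    # Streaming responses are server-sent events of the form "data: {json}",
    # as produced by chat_stream_generator.
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        body = line[len("data: "):]
        if body.strip() == "[DONE]":  # guard in case a DONE sentinel is ever emitted
            break
        chunk = json.loads(body)
        delta = chunk["choices"][0].get("delta", {}).get("content")
        if delta:
            print(delta, end="", flush=True)
    print()


if __name__ == "__main__":
    chat("Write a short poem about open source.", stream=True)
```

Note that with the `set_parameters` mapping shown above, omitting `temperature` (or passing `None`) leaves `do_sample` off, so `top_p`/`top_k` are forwarded but have no effect on greedy decoding, and `presence_penalty` is only used as a fallback when `repetition_penalty` is not supplied.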