From 23681fbf5c215c12440736f977694ca6e62efdc8 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Fri, 26 Jul 2024 09:41:03 +0800
Subject: [PATCH] Support codegeex4-9b for lightweight-serving (#11648)

* add options, support prompt input, and do not return end tokens
* enable OpenAI-compatible parameters
* set do_sample to None and update code style
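
Both /v1/chat/completions and /v1/completions now accept the common
OpenAI-style sampling parameters (max_tokens, min_tokens, top_p, top_k,
temperature, repetition_penalty, presence_penalty, stream). A minimal
client-side sketch, assuming the lightweight-serving server from the
README is listening on http://localhost:8000 with Llama-2-7b-chat-hf
loaded (URL, port, and model name are taken from the README example,
not introduced by this patch):

```python
# Minimal sketch of a client call exercising the newly supported
# OpenAI-style parameters.  Assumes the lightweight-serving server from
# the README is listening on http://localhost:8000 and serving
# "Llama-2-7b-chat-hf"; adjust the URL and model name for your setup.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "Hello! What is your name?"}],
        "max_tokens": 64,
        "temperature": 0.7,  # > 1e-4, so set_parameters enables do_sample
        "top_p": 0.9,
        "stream": False,
    },
    timeout=120,
)
# The response is assumed to follow the OpenAI chat-completion schema.
print(resp.json()["choices"][0]["message"]["content"])
```

With temperature above 1e-4 the server enables sampling (see
set_parameters below); otherwise it falls back to greedy decoding.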
---
 .../example/GPU/Lightweight-Serving/README.md |   7 +-
 .../ipex_llm/serving/fastapi/api_server.py    | 226 +++++++++++-------
 .../ipex_llm/serving/fastapi/model_worker.py  |   5 +
 3 files changed, 149 insertions(+), 89 deletions(-)
diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index 791bf166..104032c6 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
-    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
+    "stream": false
   }'
 ```
 
 #### /v1/completions
 
 ```bash
+
 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
     "prompt": "Once upon a time",
-    "max_tokens": 32
+    "max_tokens": 32,
+    "stream": false
   }'
 ```
 
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 75bd49d0..5109c822 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError
 import asyncio
 import uuid
 from typing import List, Optional, Union, Dict
+from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
-
-
-result_dict: Dict[str, str] = {}
-logger = logging.get_logger(__name__)
-
-
-class InputsRequest(BaseModel):
-    inputs: str
-    parameters: Optional[Parameters] = None
-    stream: Optional[bool] = False
-    req_type: str = 'completion'
-
-
-class ChatCompletionRequest(BaseModel):
-    messages: List[ChatCompletionMessageParam]
-    model: str
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-class CompletionRequest(BaseModel):
-    model: str
-    prompt: Union[List[int], List[List[int]], str, List[str]]
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-app = FastAPI()
-global tokenizer
-global local_model
-
-
-class FastApp():
-    def __init__(self, model, mytokenizer):
-        global tokenizer
-        global local_model
-        local_model = model
-        tokenizer = mytokenizer
-        self.app = app
-
-
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -80,6 +40,63 @@ from .openai_protocol import (
     CompletionStreamResponse,
 )
 
+result_dict: Dict[str, str] = {}
+logger = logging.get_logger(__name__)
+
+
+class InputsRequest(BaseModel):
+    inputs: str
+    parameters: Optional[Parameters] = None
+    stream: Optional[bool] = False
+    req_type: str = 'completion'
+
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    model: str
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+global tokenizer
+global local_model
+
+
+class FastApp():
+    def __init__(self, model, mytokenizer):
+        global tokenizer
+        global local_model
+        local_model = model
+        tokenizer = mytokenizer
+        self.app = app
+
 
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
@@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue):
         remain = 1
     return delta_text, remain
 
+
+def should_return_end_token(next_token):
+    if "codegeex" not in local_model.model_name.lower():
+        return True
+    else:
+        if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]:
+            return False
+    return True
+
 async def chat_stream_generator(local_model, delta_text_queue, request_id):
     model_name = local_model.model_name
     index = 0
@@ -104,18 +130,19 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id):
                 await asyncio.sleep(0)
                 continue
         if remain == 0 and delta_text is not None or remain != 0:
-            choice_data = ChatCompletionResponseStreamChoice(
-                index=index,
-                delta=DeltaMessage(role="assistant", content=delta_text),
-                logprobs=None,
-                finish_reason=None)
-            chunk = ChatCompletionStreamResponse(
-                id=request_id,
-                choices=[choice_data],
-                model=model_name)
-            data = chunk.model_dump_json(exclude_unset=True)
-            yield f"data: {data}\n\n"
-            index = index + 1
+            if should_return_end_token(delta_text):
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=index,
+                    delta=DeltaMessage(role="assistant", content=delta_text),
+                    logprobs=None,
+                    finish_reason=None)
+                chunk = ChatCompletionStreamResponse(
+                    id=request_id,
+                    choices=[choice_data],
+                    model=model_name)
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+                index = index + 1
         if remain == 0:
             choice_data = ChatCompletionResponseStreamChoice(
                 index=index,
@@ -146,18 +173,19 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id)
                 await asyncio.sleep(0)
                 continue
         if remain == 0 and delta_text is not None or remain != 0:
-            choice_data = CompletionResponseStreamChoice(
-                index=index,
-                text=delta_text,
-                logprobs=None,
-                finish_reason=None)
-            chunk = CompletionStreamResponse(
-                id=request_id,
-                choices=[choice_data],
-                model=model_name)
-            data = chunk.model_dump_json(exclude_unset=True)
-            yield f"data: {data}\n\n"
-            index = index + 1
+            if should_return_end_token(delta_text):
+                choice_data = CompletionResponseStreamChoice(
+                    index=index,
+                    text=delta_text,
+                    logprobs=None,
+                    finish_reason=None)
+                chunk = CompletionStreamResponse(
+                    id=request_id,
+                    choices=[choice_data],
+                    model=model_name)
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+                index = index + 1
         if remain == 0:
             choice_data = CompletionResponseStreamChoice(
                 index=index,
@@ -237,31 +265,59 @@ async def generate_stream(inputs_request: InputsRequest):
 
 
 def get_prompt(messages) -> str:
-    prompt = ""
-    for msg in messages:
-        role = msg["role"]
-        content = msg["content"]
-        if role == "system":
-            prompt += f"<>\n{content}\n<>\n\n"
-        elif role == "user":
-            prompt += f"[INST] {content} [/INST] "
-        elif role == "assistant":
-            prompt += f"{content} "
+    if "codegeex" in local_model.model_name.lower():
+        query = messages[-1].content
+        if len(messages) <= 1:
+            history = []
         else:
-            invalidInputError(False, f"Unknown role: {role}")
-    return prompt.strip()
+            history = [msg.model_dump() for msg in messages[:-1]]
+        history.append({"role": "user", "content": query})
+        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False,
+                                               return_tensors="pt", return_dict=False)
+        return inputs
+    else:
+        prompt = ""
+        for msg in messages:
+            role = msg.role
+            content = msg.content
+            if role == "system":
+                prompt += f"<>\n{content}\n<>\n\n"
+            elif role == "user":
+                prompt += f"[INST] {content} [/INST] "
+            elif role == "assistant":
+                prompt += f"{content} "
+            else:
+                invalidInputError(False, f"Unknown role: {role}")
+        return prompt.strip()
+
+
+def set_parameters(req):
+    if req.max_tokens is None:
+        n_predict = 256
+    else:
+        n_predict = req.max_tokens
+    if req.repetition_penalty is not None:
+        repetition_penalty = req.repetition_penalty
+    elif req.presence_penalty is not None:
+        repetition_penalty = req.presence_penalty
+    else:
+        repetition_penalty = None
+    if req.temperature is not None and req.temperature > 1e-4:
+        do_sample = True
+    else:
+        do_sample = False
+    return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens,
+                      top_p=req.top_p, repetition_penalty=repetition_penalty,
+                      temperature=req.temperature, top_k=req.top_k)
 
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
+    print(request)
     model_name = local_model.model_name
-    if request.max_tokens is None:
-        n_predict = 256
-    else:
-        n_predict = request.max_tokens
     inputs_request = InputsRequest(
         inputs=get_prompt(request.messages),
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="chat"
     )
@@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 @app.post("/v1/completions")
 async def create_completion(request: CompletionRequest):
     model_name = local_model.model_name
-    if request.max_tokens is None:
-        n_predict = 32
-    else:
-        n_predict = request.max_tokens
     inputs_request = InputsRequest(
         inputs=request.prompt,
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="completion"
     )
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index c88e1434..099cfc0c 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -73,6 +73,11 @@ class ModelWorker:
 
             def model_generate():
                 generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+                if "codegeex" in self.model_name.lower():
+                    eos_token_id = [tokenizer.eos_token_id,
+                                    tokenizer.convert_tokens_to_ids("<|user|>"),
+                                    tokenizer.convert_tokens_to_ids("<|observation|>")]
+                    generate_kwargs["eos_token_id"] = eos_token_id
                 self.model.generate(input_ids,
                                     streamer=self.streamer[request_id], **generate_kwargs)
                 torch.xpu.empty_cache()
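
Note on the codegeex stop-token handling above: the worker adds `<|user|>` and
`<|observation|>` to `eos_token_id` so generation halts at CodeGeeX4's
conversation delimiters, while `should_return_end_token` in the API layer keeps
those delimiters (and `<|endoftext|>`) out of the streamed output. A standalone
sketch of the same idea, assuming a Hugging Face tokenizer for the model (the
`THUDM/codegeex4-all-9b` id is illustrative, not taken from this patch):

```python
# Standalone sketch of the stop-token handling used for codegeex models.
# Assumes a Hugging Face checkpoint is available; the model id is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex4-all-9b",
                                          trust_remote_code=True)

# Extra end-of-turn markers passed to model.generate() so decoding stops at
# CodeGeeX4's conversation delimiters, not only at the default eos_token_id.
stop_token_ids = [tokenizer.eos_token_id,
                  tokenizer.convert_tokens_to_ids("<|user|>"),
                  tokenizer.convert_tokens_to_ids("<|observation|>")]

END_TOKENS = {"<|user|>", "<|endoftext|>", "<|observation|>"}

def should_return_end_token(next_token: str, model_name: str) -> bool:
    # Mirror of the server-side filter: only codegeex models suppress
    # their delimiter tokens in the streamed text.
    if "codegeex" not in model_name.lower():
        return True
    return next_token not in END_TOKENS

# The delimiter itself is dropped from the stream; ordinary text passes through.
assert should_return_end_token("<|user|>", "codegeex4-all-9b") is False
assert should_return_end_token("Hello", "codegeex4-all-9b") is True
```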