diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
index 791bf166..104032c6 100644
--- a/python/llm/example/GPU/Lightweight-Serving/README.md
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Llama-2-7b-chat-hf",
- "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+ "messages": [{"role": "user", "content": "Hello! What is your name?"}],
+ "stream": false
}'
```
#### /v1/completions
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Llama-2-7b-chat-hf",
"prompt": "Once upon a time",
- "max_tokens": 32
+ "max_tokens": 32,
+ "stream": false
}'
```
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
index 75bd49d0..5109c822 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError
import asyncio
import uuid
from typing import List, Optional, Union, Dict
+from fastapi.middleware.cors import CORSMiddleware
from .tgi_protocol import Parameters
-
-
-result_dict: Dict[str, str] = {}
-logger = logging.get_logger(__name__)
-
-
-class InputsRequest(BaseModel):
- inputs: str
- parameters: Optional[Parameters] = None
- stream: Optional[bool] = False
- req_type: str = 'completion'
-
-
-class ChatCompletionRequest(BaseModel):
- messages: List[ChatCompletionMessageParam]
- model: str
- max_tokens: Optional[int] = None
- stream: Optional[bool] = False
-
-
-class CompletionRequest(BaseModel):
- model: str
- prompt: Union[List[int], List[List[int]], str, List[str]]
- max_tokens: Optional[int] = None
- stream: Optional[bool] = False
-
-
-app = FastAPI()
-global tokenizer
-global local_model
-
-
-class FastApp():
- def __init__(self, model, mytokenizer):
- global tokenizer
- global local_model
- local_model = model
- tokenizer = mytokenizer
- self.app = app
-
-
from .openai_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
@@ -80,6 +40,63 @@ from .openai_protocol import (
CompletionStreamResponse,
)
+result_dict: Dict[str, str] = {}
+logger = logging.get_logger(__name__)
+
+
+class InputsRequest(BaseModel):
+ inputs: str
+ parameters: Optional[Parameters] = None
+ stream: Optional[bool] = False
+ req_type: str = 'completion'
+
+
+class ChatCompletionRequest(BaseModel):
+ messages: List[ChatMessage]
+ model: str
+ max_tokens: Optional[int] = None
+ min_tokens: Optional[int] = None
+ stream: Optional[bool] = False
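+    # optional sampling parameters; unset values are resolved in set_parameters below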
+ top_p: Optional[float] = None
+ top_k: Optional[int] = None
+ repetition_penalty: Optional[float] = None
+ presence_penalty: Optional[float] = None
+ temperature: Optional[float] = None
+
+
+class CompletionRequest(BaseModel):
+ model: str
+ prompt: Union[List[int], List[List[int]], str, List[str]]
+ max_tokens: Optional[int] = None
+ min_tokens: Optional[int] = None
+ stream: Optional[bool] = False
+ top_p: Optional[float] = None
+ top_k: Optional[int] = None
+ repetition_penalty: Optional[float] = None
+ presence_penalty: Optional[float] = None
+ temperature: Optional[float] = None
+
+
+app = FastAPI()
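+# Allow cross-origin requests so browser-based clients can call the API directly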
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+global tokenizer
+global local_model
+
+
+class FastApp():
+ def __init__(self, model, mytokenizer):
+ global tokenizer
+ global local_model
+ local_model = model
+ tokenizer = mytokenizer
+ self.app = app
+
def get_queue_next_token(delta_text_queue):
timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
@@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue):
remain = 1
return delta_text, remain
+
+def should_return_end_token(next_token):
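+    # CodeGeeX emits special role/stop tokens in its stream; suppress them in the response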
+ if "codegeex" not in local_model.model_name.lower():
+ return True
+ else:
+ if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]:
+ return False
+ return True
+
async def chat_stream_generator(local_model, delta_text_queue, request_id):
model_name = local_model.model_name
index = 0
@@ -104,18 +130,19 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id):
await asyncio.sleep(0)
continue
if remain == 0 and delta_text is not None or remain != 0:
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(role="assistant", content=delta_text),
- logprobs=None,
- finish_reason=None)
- chunk = ChatCompletionStreamResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- index = index + 1
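+            # Emit the chunk only when it is not a CodeGeeX end token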
+ if should_return_end_token(delta_text):
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(role="assistant", content=delta_text),
+ logprobs=None,
+ finish_reason=None)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ index = index + 1
if remain == 0:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
@@ -146,18 +173,19 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id)
await asyncio.sleep(0)
continue
if remain == 0 and delta_text is not None or remain != 0:
- choice_data = CompletionResponseStreamChoice(
- index=index,
- text=delta_text,
- logprobs=None,
- finish_reason=None)
- chunk = CompletionStreamResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- index = index + 1
+ if should_return_end_token(delta_text):
+ choice_data = CompletionResponseStreamChoice(
+ index=index,
+ text=delta_text,
+ logprobs=None,
+ finish_reason=None)
+ chunk = CompletionStreamResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ index = index + 1
if remain == 0:
choice_data = CompletionResponseStreamChoice(
index=index,
@@ -237,31 +265,59 @@ async def generate_stream(inputs_request: InputsRequest):
def get_prompt(messages) -> str:
- prompt = ""
- for msg in messages:
- role = msg["role"]
- content = msg["content"]
- if role == "system":
-            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
- elif role == "user":
- prompt += f"[INST] {content} [/INST] "
- elif role == "assistant":
- prompt += f"{content} "
+ if "codegeex" in local_model.model_name.lower():
+ query = messages[-1].content
+ if len(messages) <= 1:
+ history = []
else:
- invalidInputError(False, f"Unknown role: {role}")
- return prompt.strip()
+ history = [msg.model_dump() for msg in messages[:-1]]
+ history.append({"role": "user", "content": query})
+ inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False,
+ return_tensors="pt", return_dict=False)
+ return inputs
+ else:
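+        # Default: Llama-2 style [INST]/<<SYS>> prompt construction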
+ prompt = ""
+ for msg in messages:
+ role = msg.role
+ content = msg.content
+ if role == "system":
+                prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
+ elif role == "user":
+ prompt += f"[INST] {content} [/INST] "
+ elif role == "assistant":
+ prompt += f"{content} "
+ else:
+ invalidInputError(False, f"Unknown role: {role}")
+ return prompt.strip()
+
+
+def set_parameters(req):
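+    # Translate OpenAI-style request fields into ipex-llm generation Parameters;
+    # presence_penalty serves as a fallback when repetition_penalty is not given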
+ if req.max_tokens is None:
+ n_predict = 256
+ else:
+ n_predict = req.max_tokens
+ if req.repetition_penalty is not None:
+ repetition_penalty = req.repetition_penalty
+ elif req.presence_penalty is not None:
+ repetition_penalty = req.presence_penalty
+ else:
+ repetition_penalty = None
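+    # Enable sampling only when a meaningful temperature (> 1e-4) is requested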
+ if req.temperature is not None and req.temperature > 1e-4:
+ do_sample = True
+ else:
+ do_sample = False
+ return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens,
+ top_p=req.top_p, repetition_penalty=repetition_penalty,
+ temperature=req.temperature, top_k=req.top_k)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
+    logger.info(f"Received chat completion request: {request}")
model_name = local_model.model_name
- if request.max_tokens is None:
- n_predict = 256
- else:
- n_predict = request.max_tokens
inputs_request = InputsRequest(
inputs=get_prompt(request.messages),
- parameters=Parameters(max_new_tokens=n_predict),
+ parameters=set_parameters(request),
stream=request.stream,
req_type="chat"
)
@@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
model_name = local_model.model_name
- if request.max_tokens is None:
- n_predict = 32
- else:
- n_predict = request.max_tokens
inputs_request = InputsRequest(
inputs=request.prompt,
- parameters=Parameters(max_new_tokens=n_predict),
+ parameters=set_parameters(request),
stream=request.stream,
req_type="completion"
)
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
index c88e1434..099cfc0c 100644
--- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -73,6 +73,11 @@ class ModelWorker:
def model_generate():
generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+ if "codegeex" in self.model_name.lower():
+ eos_token_id = [tokenizer.eos_token_id,
+ tokenizer.convert_tokens_to_ids("<|user|>"),
+ tokenizer.convert_tokens_to_ids("<|observation|>")]
+ generate_kwargs["eos_token_id"] = eos_token_id
self.model.generate(input_ids,
streamer=self.streamer[request_id], **generate_kwargs)
torch.xpu.empty_cache()