Support codegeex4-9b for lightweight-serving (#11648)
* add options, support prompt and do not return end_token
* enable openai parameters
* set do_sample to None and update style
parent 86fc0492f4
commit 23681fbf5c

3 changed files with 149 additions and 89 deletions
@@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
-    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
+    "stream": false
   }'
 ```
 
 #### /v1/completions
 
 ```bash
 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
     "prompt": "Once upon a time",
-    "max_tokens": 32
+    "max_tokens": 32,
+    "stream": false
   }'
 ```
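The updated README examples add an explicit `"stream": false` field to both endpoints. For reference, the same non-streaming /v1/completions call can be issued from Python; this is a minimal sketch assuming the server from the README is running on localhost:8000 and the `requests` package is installed:

```python
import requests

# Mirrors the README's curl example for /v1/completions, with streaming disabled.
payload = {
    "model": "Llama-2-7b-chat-hf",
    "prompt": "Once upon a time",
    "max_tokens": 32,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/completions", json=payload)
resp.raise_for_status()
print(resp.json())
```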
@@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError
 import asyncio
 import uuid
 from typing import List, Optional, Union, Dict
+from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
-
-
-result_dict: Dict[str, str] = {}
-logger = logging.get_logger(__name__)
-
-
-class InputsRequest(BaseModel):
-    inputs: str
-    parameters: Optional[Parameters] = None
-    stream: Optional[bool] = False
-    req_type: str = 'completion'
-
-
-class ChatCompletionRequest(BaseModel):
-    messages: List[ChatCompletionMessageParam]
-    model: str
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-class CompletionRequest(BaseModel):
-    model: str
-    prompt: Union[List[int], List[List[int]], str, List[str]]
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-app = FastAPI()
-global tokenizer
-global local_model
-
-
-class FastApp():
-    def __init__(self, model, mytokenizer):
-        global tokenizer
-        global local_model
-        local_model = model
-        tokenizer = mytokenizer
-        self.app = app
-
-
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -80,6 +40,63 @@ from .openai_protocol import (
     CompletionStreamResponse,
 )
 
+result_dict: Dict[str, str] = {}
+logger = logging.get_logger(__name__)
+
+
+class InputsRequest(BaseModel):
+    inputs: str
+    parameters: Optional[Parameters] = None
+    stream: Optional[bool] = False
+    req_type: str = 'completion'
+
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    model: str
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+global tokenizer
+global local_model
+
+
+class FastApp():
+    def __init__(self, model, mytokenizer):
+        global tokenizer
+        global local_model
+        local_model = model
+        tokenizer = mytokenizer
+        self.app = app
+
+
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
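The rewritten request models expose OpenAI-style sampling options (`top_p`, `top_k`, `temperature`, `repetition_penalty`, `presence_penalty`, `min_tokens`) on both endpoints. A usage sketch from the client side, assuming the same localhost:8000 server and the `requests` package; the concrete values are illustrative only:

```python
import requests

# Exercise the new optional sampling fields on /v1/chat/completions.
payload = {
    "model": "Llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
    "max_tokens": 64,
    "temperature": 0.7,   # above 1e-4, so the server enables sampling
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
```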
@@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue):
             remain = 1
     return delta_text, remain
 
 
+def should_return_end_token(next_token):
+    if "codegeex" not in local_model.model_name.lower():
+        return True
+    else:
+        if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]:
+            return False
+    return True
+
+
 async def chat_stream_generator(local_model, delta_text_queue, request_id):
     model_name = local_model.model_name
     index = 0
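`should_return_end_token` suppresses CodeGeeX's special end-of-turn tokens in streamed output. A standalone sketch of the same check, with the model name passed in explicitly instead of read from the global `local_model` (the helper name here is hypothetical):

```python
CODEGEEX_STOP_TOKENS = {"<|user|>", "<|endoftext|>", "<|observation|>"}

def keep_streaming_token(next_token: str, model_name: str) -> bool:
    # Non-CodeGeeX models stream every token; CodeGeeX4 drops its role/end markers.
    if "codegeex" not in model_name.lower():
        return True
    return next_token not in CODEGEEX_STOP_TOKENS

assert keep_streaming_token("<|user|>", "codegeex4-all-9b") is False
assert keep_streaming_token("Hello", "codegeex4-all-9b") is True
assert keep_streaming_token("<|user|>", "Llama-2-7b-chat-hf") is True
```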
@@ -104,6 +130,7 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id):
             await asyncio.sleep(0)
             continue
         if remain == 0 and delta_text is not None or remain != 0:
+            if should_return_end_token(delta_text):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(role="assistant", content=delta_text),
@@ -146,6 +173,7 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id):
             await asyncio.sleep(0)
             continue
         if remain == 0 and delta_text is not None or remain != 0:
+            if should_return_end_token(delta_text):
                 choice_data = CompletionResponseStreamChoice(
                     index=index,
                     text=delta_text,
@@ -237,10 +265,21 @@ async def generate_stream(inputs_request: InputsRequest):
 
 
 def get_prompt(messages) -> str:
+    if "codegeex" in local_model.model_name.lower():
+        query = messages[-1].content
+        if len(messages) <= 1:
+            history = []
+        else:
+            history = [msg.model_dump() for msg in messages[:-1]]
+        history.append({"role": "user", "content": query})
+        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False,
+                                               return_tensors="pt", return_dict=False)
+        return inputs
+    else:
         prompt = ""
         for msg in messages:
-            role = msg["role"]
-            content = msg["content"]
+            role = msg.role
+            content = msg.content
             if role == "system":
                 prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
             elif role == "user":
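For CodeGeeX models, `get_prompt` now builds the prompt from the tokenizer's chat template rather than the Llama-style role string. A rough sketch of that path outside the server, assuming `transformers` is installed and a codegeex4 checkpoint is available; the model id shown is an assumption:

```python
from transformers import AutoTokenizer

# Hypothetical checkpoint; substitute whatever model the server actually loads.
tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex4-all-9b", trust_remote_code=True)

history = [
    {"role": "system", "content": "You are a helpful coding assistant."},
    {"role": "user", "content": "Write a hello-world program in Python."},
]
# Build the text prompt the same way the serving code does, minus the tensor options.
prompt = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
print(prompt)
```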
@@ -252,16 +291,33 @@ def get_prompt(messages) -> str:
     return prompt.strip()
 
 
-@app.post("/v1/chat/completions")
-async def create_chat_completion(request: ChatCompletionRequest):
-    model_name = local_model.model_name
-    if request.max_tokens is None:
+def set_parameters(req):
+    if req.max_tokens is None:
         n_predict = 256
     else:
-        n_predict = request.max_tokens
+        n_predict = req.max_tokens
+    if req.repetition_penalty is not None:
+        repetition_penalty = req.repetition_penalty
+    elif req.presence_penalty is not None:
+        repetition_penalty = req.presence_penalty
+    else:
+        repetition_penalty = None
+    if req.temperature is not None and req.temperature > 1e-4:
+        do_sample = True
+    else:
+        do_sample = False
+    return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens,
+                      top_p=req.top_p, repetition_penalty=repetition_penalty,
+                      temperature=req.temperature, top_k=req.top_k)
+
+
+@app.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest):
+    print(request)
+    model_name = local_model.model_name
     inputs_request = InputsRequest(
         inputs=get_prompt(request.messages),
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
        stream=request.stream,
        req_type="chat"
    )
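`set_parameters` translates the OpenAI-style request fields into the worker's generation parameters: `repetition_penalty` takes precedence over `presence_penalty`, and sampling is only enabled when `temperature` is above 1e-4. The mapping can be summarized in a standalone sketch that returns a plain dict instead of the project's `Parameters` object:

```python
def map_openai_params(max_tokens=None, min_tokens=None, temperature=None, top_p=None,
                      top_k=None, repetition_penalty=None, presence_penalty=None):
    # repetition_penalty wins over presence_penalty when both are given.
    penalty = repetition_penalty if repetition_penalty is not None else presence_penalty
    # A missing or near-zero temperature keeps decoding greedy.
    do_sample = temperature is not None and temperature > 1e-4
    return {
        "max_new_tokens": max_tokens if max_tokens is not None else 256,
        "min_new_tokens": min_tokens,
        "do_sample": do_sample,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": penalty,
    }

print(map_openai_params())                                           # greedy, 256 new tokens by default
print(map_openai_params(temperature=0.8, top_p=0.9, max_tokens=64))  # sampling enabled
```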
@@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 @app.post("/v1/completions")
 async def create_completion(request: CompletionRequest):
     model_name = local_model.model_name
-    if request.max_tokens is None:
-        n_predict = 32
-    else:
-        n_predict = request.max_tokens
     inputs_request = InputsRequest(
         inputs=request.prompt,
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="completion"
     )
@@ -73,6 +73,11 @@ class ModelWorker:
 
         def model_generate():
            generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+            if "codegeex" in self.model_name.lower():
+                eos_token_id = [tokenizer.eos_token_id,
+                                tokenizer.convert_tokens_to_ids("<|user|>"),
+                                tokenizer.convert_tokens_to_ids("<|observation|>")]
+                generate_kwargs["eos_token_id"] = eos_token_id
            self.model.generate(input_ids,
                                streamer=self.streamer[request_id], **generate_kwargs)
            torch.xpu.empty_cache()
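On the worker side, CodeGeeX4 checkpoints get their extra end-of-turn markers registered as EOS ids so generation stops at `<|user|>` / `<|observation|>` as well as the normal end token. A simplified sketch of that kwargs assembly, assuming a `transformers` tokenizer is already loaded (the helper name is illustrative):

```python
def build_generate_kwargs(parameters: dict, model_name: str, tokenizer) -> dict:
    # Keep only the generation options the request actually set.
    kwargs = {k: v for k, v in parameters.items() if v is not None}
    if "codegeex" in model_name.lower():
        # Treat CodeGeeX4's role markers as additional EOS tokens.
        kwargs["eos_token_id"] = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|user|>"),
            tokenizer.convert_tokens_to_ids("<|observation|>"),
        ]
    return kwargs
```

The worker then forwards these kwargs, together with the request's streamer, to `model.generate`.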