Support codegeex4-9b for lightweight-serving (#11648)

* add generation options, support prompt construction, and do not return end tokens

* enable OpenAI-style request parameters

* set do_sample to None and update style
Wang, Jian4 2024-07-26 09:41:03 +08:00 committed by GitHub
parent 86fc0492f4
commit 23681fbf5c
3 changed files with 149 additions and 89 deletions


@@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
    "stream": false
  }'
```
#### /v1/completions
```bash
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Llama-2-7b-chat-hf",
    "prompt": "Once upon a time",
    "max_tokens": 32
    "max_tokens": 32,
    "stream": false
  }'
```
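
The api_server changes below add OpenAI-style sampling fields (`max_tokens`, `min_tokens`, `temperature`, `top_p`, `top_k`, `repetition_penalty`, `presence_penalty`) to both request models. A minimal Python client sketch, assuming the server from the README above is running on localhost:8000 and returns the standard OpenAI response shape:

```python
# Illustrative only: endpoint and model name follow the README example above;
# the sampling values are arbitrary.
import requests

payload = {
    "model": "Llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
    "max_tokens": 64,
    "temperature": 0.7,        # > 1e-4, so the server turns sampling on
    "top_p": 0.9,
    "repetition_penalty": 1.1,
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```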


@@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError
import asyncio
import uuid
from typing import List, Optional, Union, Dict
from fastapi.middleware.cors import CORSMiddleware
from .tgi_protocol import Parameters
result_dict: Dict[str, str] = {}
logger = logging.get_logger(__name__)
class InputsRequest(BaseModel):
    inputs: str
    parameters: Optional[Parameters] = None
    stream: Optional[bool] = False
    req_type: str = 'completion'
class ChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionMessageParam]
    model: str
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
class CompletionRequest(BaseModel):
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
app = FastAPI()
global tokenizer
global local_model
class FastApp():
    def __init__(self, model, mytokenizer):
        global tokenizer
        global local_model
        local_model = model
        tokenizer = mytokenizer
        self.app = app
from .openai_protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
@@ -80,6 +40,63 @@ from .openai_protocol import (
    CompletionStreamResponse,
)
result_dict: Dict[str, str] = {}
logger = logging.get_logger(__name__)
class InputsRequest(BaseModel):
    inputs: str
    parameters: Optional[Parameters] = None
    stream: Optional[bool] = False
    req_type: str = 'completion'
class ChatCompletionRequest(BaseModel):
    messages: List[ChatMessage]
    model: str
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = False
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    temperature: Optional[float] = None
class CompletionRequest(BaseModel):
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = False
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    temperature: Optional[float] = None
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
global tokenizer
global local_model
class FastApp():
    def __init__(self, model, mytokenizer):
        global tokenizer
        global local_model
        local_model = model
        tokenizer = mytokenizer
        self.app = app
def get_queue_next_token(delta_text_queue):
    timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
@@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue):
        remain = 1
    return delta_text, remain
def should_return_end_token(next_token):
    if "codegeex" not in local_model.model_name.lower():
        return True
    else:
        if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]:
            return False
        return True
async def chat_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
@@ -104,6 +130,7 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id):
                await asyncio.sleep(0)
                continue
        if remain == 0 and delta_text is not None or remain != 0:
            if should_return_end_token(delta_text):
                choice_data = ChatCompletionResponseStreamChoice(
                    index=index,
                    delta=DeltaMessage(role="assistant", content=delta_text),
@@ -146,6 +173,7 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id):
                await asyncio.sleep(0)
                continue
        if remain == 0 and delta_text is not None or remain != 0:
            if should_return_end_token(delta_text):
                choice_data = CompletionResponseStreamChoice(
                    index=index,
                    text=delta_text,
@@ -237,10 +265,21 @@ async def generate_stream(inputs_request: InputsRequest):
def get_prompt(messages) -> str:
    if "codegeex" in local_model.model_name.lower():
        query = messages[-1].content
        if len(messages) <= 1:
            history = []
        else:
            history = [msg.model_dump() for msg in messages[:-1]]
        history.append({"role": "user", "content": query})
        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False,
                                               return_tensors="pt", return_dict=False)
        return inputs
    else:
        prompt = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            role = msg.role
            content = msg.content
            if role == "system":
                prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
            elif role == "user":
@@ -252,16 +291,33 @@ def get_prompt(messages) -> str:
    return prompt.strip()
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
def set_parameters(req):
    if req.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
        n_predict = req.max_tokens
    if req.repetition_penalty is not None:
        repetition_penalty = req.repetition_penalty
    elif req.presence_penalty is not None:
        repetition_penalty = req.presence_penalty
    else:
        repetition_penalty = None
    if req.temperature is not None and req.temperature > 1e-4:
        do_sample = True
    else:
        do_sample = False
    return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens,
                      top_p=req.top_p, repetition_penalty=repetition_penalty,
                      temperature=req.temperature, top_k=req.top_k)
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    print(request)
    model_name = local_model.model_name
    inputs_request = InputsRequest(
        inputs=get_prompt(request.messages),
        parameters=Parameters(max_new_tokens=n_predict),
        parameters=set_parameters(request),
        stream=request.stream,
        req_type="chat"
    )
@@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 32
    else:
        n_predict = request.max_tokens
    inputs_request = InputsRequest(
        inputs=request.prompt,
        parameters=Parameters(max_new_tokens=n_predict),
        parameters=set_parameters(request),
        stream=request.stream,
        req_type="completion"
    )


@@ -73,6 +73,11 @@ class ModelWorker:
        def model_generate():
            generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
            if "codegeex" in self.model_name.lower():
                eos_token_id = [tokenizer.eos_token_id,
                                tokenizer.convert_tokens_to_ids("<|user|>"),
                                tokenizer.convert_tokens_to_ids("<|observation|>")]
                generate_kwargs["eos_token_id"] = eos_token_id
            self.model.generate(input_ids,
                                streamer=self.streamer[request_id], **generate_kwargs)
            torch.xpu.empty_cache()
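
For reference, a standalone sketch of the stop-token handling this last hunk adds: Hugging Face `generate` accepts a list of `eos_token_id` values and halts at whichever stop token appears first, which is why the extra codegeex tokens can be passed alongside the regular EOS id. The model id and prompt below are illustrative, not taken from this PR.

```python
# Sketch only: model id and prompt are assumptions for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "THUDM/codegeex4-all-9b"  # assumed codegeex4 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Same stop-token list the model_worker builds for codegeex models.
stop_ids = [tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|user|>"),
            tokenizer.convert_tokens_to_ids("<|observation|>")]

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a quicksort in Python."}],
    add_generation_prompt=True, tokenize=True, return_tensors="pt")
output = model.generate(input_ids, max_new_tokens=128, eos_token_id=stop_ids)

# Generation ends at <|user|>, <|observation|>, or <|endoftext|>; the api_server
# then drops that trailing token via should_return_end_token().
print(tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=False))
```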