Support codegeex4-9b for lightweight-serving (#11648)

* Add options, support prompt, and do not return end_token

* Enable OpenAI parameters

* Set do_sample to None and update style
Wang, Jian4 authored 2024-07-26 09:41:03 +08:00, committed by GitHub
parent 86fc0492f4
commit 23681fbf5c
3 changed files with 149 additions and 89 deletions

Changed file 1 of 3

@@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
-    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
+    "stream": false
   }'
 ```
 #### /v1/completions
 ```bash
 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
     "prompt": "Once upon a time",
-    "max_tokens": 32
+    "max_tokens": 32,
+    "stream": false
   }'
 ```
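
For reference, the sketch below is an illustrative Python client for the two endpoints documented above, exercising the newly documented `"stream": false` option. It is not part of this commit; the server address, model name, and the `requests` dependency are assumptions taken from the README snippets.

```python
# Illustrative client for the endpoints documented above (not part of this commit).
# Assumes the lightweight-serving server is running on localhost:8000 with
# "Llama-2-7b-chat-hf" loaded, and that the `requests` package is installed.
import requests

BASE_URL = "http://localhost:8000"  # assumed server address from the README examples

# Non-streaming chat completion, using the new "stream": false option.
chat_resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "Hello! What is your name?"}],
        "stream": False,
    },
)
print(chat_resp.json())

# Non-streaming text completion with an explicit token budget.
completion_resp = requests.post(
    f"{BASE_URL}/v1/completions",
    json={
        "model": "Llama-2-7b-chat-hf",
        "prompt": "Once upon a time",
        "max_tokens": 32,
        "stream": False,
    },
)
print(completion_resp.json())
```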

Changed file 2 of 3

@@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError
 import asyncio
 import uuid
 from typing import List, Optional, Union, Dict
+from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
-
-result_dict: Dict[str, str] = {}
-logger = logging.get_logger(__name__)
-
-
-class InputsRequest(BaseModel):
-    inputs: str
-    parameters: Optional[Parameters] = None
-    stream: Optional[bool] = False
-    req_type: str = 'completion'
-
-
-class ChatCompletionRequest(BaseModel):
-    messages: List[ChatCompletionMessageParam]
-    model: str
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-class CompletionRequest(BaseModel):
-    model: str
-    prompt: Union[List[int], List[List[int]], str, List[str]]
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-app = FastAPI()
-global tokenizer
-global local_model
-
-
-class FastApp():
-    def __init__(self, model, mytokenizer):
-        global tokenizer
-        global local_model
-        local_model = model
-        tokenizer = mytokenizer
-        self.app = app
-
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -80,6 +40,63 @@ from .openai_protocol import (
     CompletionStreamResponse,
 )
+
+result_dict: Dict[str, str] = {}
+logger = logging.get_logger(__name__)
+
+
+class InputsRequest(BaseModel):
+    inputs: str
+    parameters: Optional[Parameters] = None
+    stream: Optional[bool] = False
+    req_type: str = 'completion'
+
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    model: str
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+global tokenizer
+global local_model
+
+
+class FastApp():
+    def __init__(self, model, mytokenizer):
+        global tokenizer
+        global local_model
+        local_model = model
+        tokenizer = mytokenizer
+        self.app = app
+
 
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
@@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue):
         remain = 1
     return delta_text, remain
 
+
+def should_return_end_token(next_token):
+    if "codegeex" not in local_model.model_name.lower():
+        return True
+    else:
+        if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]:
+            return False
+    return True
+
 
 async def chat_stream_generator(local_model, delta_text_queue, request_id):
     model_name = local_model.model_name
     index = 0
@@ -104,18 +130,19 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id):
             await asyncio.sleep(0)
             continue
         if remain == 0 and delta_text is not None or remain != 0:
-            choice_data = ChatCompletionResponseStreamChoice(
-                index=index,
-                delta=DeltaMessage(role="assistant", content=delta_text),
-                logprobs=None,
-                finish_reason=None)
-            chunk = ChatCompletionStreamResponse(
-                id=request_id,
-                choices=[choice_data],
-                model=model_name)
-            data = chunk.model_dump_json(exclude_unset=True)
-            yield f"data: {data}\n\n"
-            index = index + 1
+            if should_return_end_token(delta_text):
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=index,
+                    delta=DeltaMessage(role="assistant", content=delta_text),
+                    logprobs=None,
+                    finish_reason=None)
+                chunk = ChatCompletionStreamResponse(
+                    id=request_id,
+                    choices=[choice_data],
+                    model=model_name)
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+                index = index + 1
         if remain == 0:
             choice_data = ChatCompletionResponseStreamChoice(
                 index=index,
@@ -146,18 +173,19 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id):
             await asyncio.sleep(0)
             continue
         if remain == 0 and delta_text is not None or remain != 0:
-            choice_data = CompletionResponseStreamChoice(
-                index=index,
-                text=delta_text,
-                logprobs=None,
-                finish_reason=None)
-            chunk = CompletionStreamResponse(
-                id=request_id,
-                choices=[choice_data],
-                model=model_name)
-            data = chunk.model_dump_json(exclude_unset=True)
-            yield f"data: {data}\n\n"
-            index = index + 1
+            if should_return_end_token(delta_text):
+                choice_data = CompletionResponseStreamChoice(
+                    index=index,
+                    text=delta_text,
+                    logprobs=None,
+                    finish_reason=None)
+                chunk = CompletionStreamResponse(
+                    id=request_id,
+                    choices=[choice_data],
+                    model=model_name)
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+                index = index + 1
         if remain == 0:
             choice_data = CompletionResponseStreamChoice(
                 index=index,
@@ -237,31 +265,59 @@ async def generate_stream(inputs_request: InputsRequest):
 def get_prompt(messages) -> str:
-    prompt = ""
-    for msg in messages:
-        role = msg["role"]
-        content = msg["content"]
-        if role == "system":
-            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
-        elif role == "user":
-            prompt += f"[INST] {content} [/INST] "
-        elif role == "assistant":
-            prompt += f"{content} "
-        else:
-            invalidInputError(False, f"Unknown role: {role}")
-    return prompt.strip()
+    if "codegeex" in local_model.model_name.lower():
+        query = messages[-1].content
+        if len(messages) <= 1:
+            history = []
+        else:
+            history = [msg.model_dump() for msg in messages[:-1]]
+        history.append({"role": "user", "content": query})
+        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False,
+                                               return_tensors="pt", return_dict=False)
+        return inputs
+    else:
+        prompt = ""
+        for msg in messages:
+            role = msg.role
+            content = msg.content
+            if role == "system":
+                prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
+            elif role == "user":
+                prompt += f"[INST] {content} [/INST] "
+            elif role == "assistant":
+                prompt += f"{content} "
+            else:
+                invalidInputError(False, f"Unknown role: {role}")
+        return prompt.strip()
+
+
+def set_parameters(req):
+    if req.max_tokens is None:
+        n_predict = 256
+    else:
+        n_predict = req.max_tokens
+    if req.repetition_penalty is not None:
+        repetition_penalty = req.repetition_penalty
+    elif req.presence_penalty is not None:
+        repetition_penalty = req.presence_penalty
+    else:
+        repetition_penalty = None
+    if req.temperature is not None and req.temperature > 1e-4:
+        do_sample = True
+    else:
+        do_sample = False
+    return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens,
+                      top_p=req.top_p, repetition_penalty=repetition_penalty,
+                      temperature=req.temperature, top_k=req.top_k)
 
 
 @app.post("/v1/chat/completions")
 async def create_chat_completion(request: ChatCompletionRequest):
+    print(request)
     model_name = local_model.model_name
-    if request.max_tokens is None:
-        n_predict = 256
-    else:
-        n_predict = request.max_tokens
     inputs_request = InputsRequest(
         inputs=get_prompt(request.messages),
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="chat"
     )
@@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 @app.post("/v1/completions")
 async def create_completion(request: CompletionRequest):
     model_name = local_model.model_name
-    if request.max_tokens is None:
-        n_predict = 32
-    else:
-        n_predict = request.max_tokens
     inputs_request = InputsRequest(
         inputs=request.prompt,
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="completion"
     )
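
The sketch below is a hedged client-side example, not part of the commit, that exercises the OpenAI-style fields `ChatCompletionRequest` now accepts and that `set_parameters()` maps onto the internal `Parameters` (a temperature at or below 1e-4 falls back to greedy decoding, and `presence_penalty` is reused as `repetition_penalty` when the latter is unset). The server address, model name, and the `requests` dependency are assumptions.

```python
# Hypothetical request exercising the OpenAI-style sampling fields that
# ChatCompletionRequest now accepts and set_parameters() maps onto Parameters.
# Assumes a running server at localhost:8000; the model name is illustrative.
import json
import requests

payload = {
    "model": "Llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "max_tokens": 64,         # -> max_new_tokens (defaults to 256 when omitted)
    "temperature": 0.7,       # > 1e-4, so sampling is enabled (do_sample=True)
    "top_p": 0.9,
    "top_k": 40,
    "presence_penalty": 1.1,  # used as repetition_penalty when that field is unset
    "stream": True,
}

# The streaming response is served as SSE lines of the form "data: {...}\n\n",
# as produced by chat_stream_generator().
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data.strip() == "[DONE]":  # guard for a possible end-of-stream sentinel
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0].get("delta", {}).get("content") or ""
        print(delta, end="", flush=True)
```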

Changed file 3 of 3

@@ -73,6 +73,11 @@ class ModelWorker:
         def model_generate():
             generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+            if "codegeex" in self.model_name.lower():
+                eos_token_id = [tokenizer.eos_token_id,
+                                tokenizer.convert_tokens_to_ids("<|user|>"),
+                                tokenizer.convert_tokens_to_ids("<|observation|>")]
+                generate_kwargs["eos_token_id"] = eos_token_id
             self.model.generate(input_ids,
                                 streamer=self.streamer[request_id], **generate_kwargs)
             torch.xpu.empty_cache()
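
As a side note, the extra CodeGeeX stop tokens can be inspected with a plain Hugging Face tokenizer. The minimal sketch below mirrors the list built in `model_generate()`; it is not part of the commit, and the model id and the `trust_remote_code` flag are assumptions.

```python
# Standalone sketch (not part of the commit) showing how the extra CodeGeeX stop
# tokens resolve to ids with a Hugging Face tokenizer. The model id
# "THUDM/codegeex4-all-9b" and the trust_remote_code flag are assumptions here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/codegeex4-all-9b", trust_remote_code=True)

# Mirrors the list built in model_generate(): the regular EOS id plus the ids of
# the "<|user|>" and "<|observation|>" role markers, so generation stops as soon
# as the model starts a new conversational turn.
eos_token_id = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|user|>"),
    tokenizer.convert_tokens_to_ids("<|observation|>"),
]
print(eos_token_id)

# These ids can then be passed to model.generate(..., eos_token_id=eos_token_id),
# which accepts either a single id or a list of ids.
```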