Support codegeex4-9b for lightweight-serving (#11648)
* add options, support prompt and not return end_token
* enable openai parameter
* set do_sample None and update style
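
For illustration, an OpenAI-style request that exercises the newly accepted options might look like the sketch below. This is not part of the commit; the host, port, and model name are assumptions (the model name must match whatever the lightweight-serving instance was launched with, shown here as a codegeex4 placeholder).

```python
# Illustrative client only. Host, port, and model name are assumptions; replace them
# with the values your lightweight-serving instance actually uses.
import json
import urllib.request

payload = {
    "model": "codegeex4-all-9b",   # placeholder: the name the server was launched with
    "messages": [{"role": "user", "content": "Write a quick sort in Python."}],
    "max_tokens": 128,
    "temperature": 0.2,            # > 1e-4 enables sampling in this change
    "top_p": 0.95,
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read().decode("utf-8")))
```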
parent 86fc0492f4
commit 23681fbf5c

3 changed files with 149 additions and 89 deletions
@@ -167,19 +167,22 @@ curl http://localhost:8000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
-    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+    "messages": [{"role": "user", "content": "Hello! What is your name?"}],
+    "stream": false
   }'
 ```

 #### /v1/completions

 ```bash

 curl http://localhost:8000/v1/completions \
   -H "Content-Type: application/json" \
   -d '{
     "model": "Llama-2-7b-chat-hf",
     "prompt": "Once upon a time",
-    "max_tokens": 32
+    "max_tokens": 32,
+    "stream": false
   }'
 ```
@@ -25,48 +25,8 @@ from ipex_llm.utils.common import invalidInputError
 import asyncio
 import uuid
 from typing import List, Optional, Union, Dict
+from fastapi.middleware.cors import CORSMiddleware
 from .tgi_protocol import Parameters
-
-
-result_dict: Dict[str, str] = {}
-logger = logging.get_logger(__name__)
-
-
-class InputsRequest(BaseModel):
-    inputs: str
-    parameters: Optional[Parameters] = None
-    stream: Optional[bool] = False
-    req_type: str = 'completion'
-
-
-class ChatCompletionRequest(BaseModel):
-    messages: List[ChatCompletionMessageParam]
-    model: str
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-class CompletionRequest(BaseModel):
-    model: str
-    prompt: Union[List[int], List[List[int]], str, List[str]]
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-
-app = FastAPI()
-global tokenizer
-global local_model
-
-
-class FastApp():
-    def __init__(self, model, mytokenizer):
-        global tokenizer
-        global local_model
-        local_model = model
-        tokenizer = mytokenizer
-        self.app = app
-
-
 from .openai_protocol import (
     ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse,
@@ -80,6 +40,63 @@ from .openai_protocol import (
     CompletionStreamResponse,
 )

+result_dict: Dict[str, str] = {}
+logger = logging.get_logger(__name__)
+
+
+class InputsRequest(BaseModel):
+    inputs: str
+    parameters: Optional[Parameters] = None
+    stream: Optional[bool] = False
+    req_type: str = 'completion'
+
+
+class ChatCompletionRequest(BaseModel):
+    messages: List[ChatMessage]
+    model: str
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+class CompletionRequest(BaseModel):
+    model: str
+    prompt: Union[List[int], List[List[int]], str, List[str]]
+    max_tokens: Optional[int] = None
+    min_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    repetition_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    temperature: Optional[float] = None
+
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+global tokenizer
+global local_model
+
+
+class FastApp():
+    def __init__(self, model, mytokenizer):
+        global tokenizer
+        global local_model
+        local_model = model
+        tokenizer = mytokenizer
+        self.app = app
+
+
 def get_queue_next_token(delta_text_queue):
     timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
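
The request classes above gain optional OpenAI-style sampling fields. A minimal, self-contained sketch of that widened schema is shown below; it omits the messages: List[ChatMessage] field because ChatMessage lives in the serving package's openai_protocol, and the model name used is a placeholder.

```python
# Simplified stand-in for the expanded ChatCompletionRequest above; unset sampling
# fields stay None and are dropped later before generate() is called.
from typing import Optional
from pydantic import BaseModel

class ChatCompletionRequestSketch(BaseModel):
    model: str
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = None
    stream: Optional[bool] = False
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    temperature: Optional[float] = None

print(ChatCompletionRequestSketch(model="codegeex4-all-9b", temperature=0.2, top_p=0.9))
```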
@@ -90,6 +107,15 @@ def get_queue_next_token(delta_text_queue):
         remain = 1
     return delta_text, remain

+
+def should_return_end_token(next_token):
+    if "codegeex" not in local_model.model_name.lower():
+        return True
+    else:
+        if next_token in ["<|user|>", "<|endoftext|>", "<|observation|>"]:
+            return False
+    return True
+
 async def chat_stream_generator(local_model, delta_text_queue, request_id):
     model_name = local_model.model_name
     index = 0
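
The filter added here keeps codegeex4's chat-control tokens (e.g. <|user|>) out of the streamed response text while leaving other models untouched. A standalone sketch of the same rule, with the model name passed explicitly instead of read from the global local_model:

```python
# Self-contained rewrite of should_return_end_token for illustration only; model names
# below are placeholders.
def should_return_end_token(next_token: str, model_name: str) -> bool:
    if "codegeex" not in model_name.lower():
        return True
    return next_token not in ["<|user|>", "<|endoftext|>", "<|observation|>"]

print(should_return_end_token("<|user|>", "codegeex4-all-9b"))        # False: filtered out
print(should_return_end_token("def quick_sort", "codegeex4-all-9b"))  # True: streamed as usual
print(should_return_end_token("<|user|>", "Llama-2-7b-chat-hf"))      # True: other models unchanged
```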
@@ -104,6 +130,7 @@ async def chat_stream_generator(local_model, delta_text_queue, request_id):
                 await asyncio.sleep(0)
                 continue
         if remain == 0 and delta_text is not None or remain != 0:
+            if should_return_end_token(delta_text):
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=index,
                     delta=DeltaMessage(role="assistant", content=delta_text),
@@ -146,6 +173,7 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id)
                 await asyncio.sleep(0)
                 continue
         if remain == 0 and delta_text is not None or remain != 0:
+            if should_return_end_token(delta_text):
                 choice_data = CompletionResponseStreamChoice(
                     index=index,
                     text=delta_text,
@@ -237,10 +265,21 @@ async def generate_stream(inputs_request: InputsRequest):


 def get_prompt(messages) -> str:
-    prompt = ""
-    for msg in messages:
-        role = msg["role"]
-        content = msg["content"]
+    if "codegeex" in local_model.model_name.lower():
+        query = messages[-1].content
+        if len(messages) <= 1:
+            history = []
+        else:
+            history = [msg.model_dump() for msg in messages[:-1]]
+        history.append({"role": "user", "content": query})
+        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False,
+                                               return_tensors="pt", return_dict=False)
+        return inputs
+    else:
+        prompt = ""
+        for msg in messages:
+            role = msg.role
+            content = msg.content
             if role == "system":
                 prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
             elif role == "user":
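
In the codegeex branch above, the prompt is built by handing the chat history to tokenizer.apply_chat_template. The sketch below shows only the history construction, with a small dataclass standing in for the server's pydantic ChatMessage; a real tokenizer would be needed for the actual template rendering.

```python
# History construction as in the codegeex branch of get_prompt; Msg is a stand-in for
# the server's ChatMessage model, and no tokenizer is involved in this sketch.
from dataclasses import dataclass

@dataclass
class Msg:
    role: str
    content: str

    def model_dump(self):
        return {"role": self.role, "content": self.content}

messages = [Msg("user", "hello"), Msg("assistant", "Hi, how can I help?"), Msg("user", "write quick sort")]
history = [] if len(messages) <= 1 else [m.model_dump() for m in messages[:-1]]
history.append({"role": "user", "content": messages[-1].content})
print(history)
# With a real tokenizer:
# tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False)
```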
@@ -252,16 +291,33 @@ def get_prompt(messages) -> str:
         return prompt.strip()


-@app.post("/v1/chat/completions")
-async def create_chat_completion(request: ChatCompletionRequest):
-    model_name = local_model.model_name
-    if request.max_tokens is None:
+def set_parameters(req):
+    if req.max_tokens is None:
         n_predict = 256
     else:
-        n_predict = request.max_tokens
+        n_predict = req.max_tokens
+    if req.repetition_penalty is not None:
+        repetition_penalty = req.repetition_penalty
+    elif req.presence_penalty is not None:
+        repetition_penalty = req.presence_penalty
+    else:
+        repetition_penalty = None
+    if req.temperature is not None and req.temperature > 1e-4:
+        do_sample = True
+    else:
+        do_sample = False
+    return Parameters(max_new_tokens=n_predict, do_sample=do_sample, min_new_tokens=req.min_tokens,
+                      top_p=req.top_p, repetition_penalty=repetition_penalty,
+                      temperature=req.temperature, top_k=req.top_k)
+
+
+@app.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest):
+    print(request)
+    model_name = local_model.model_name
     inputs_request = InputsRequest(
         inputs=get_prompt(request.messages),
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="chat"
     )
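
For clarity, the mapping performed by set_parameters can be read as the pure function below; a plain dict stands in for the serving module's Parameters class so the sketch runs on its own.

```python
# Stand-alone restatement of set_parameters: max_tokens defaults to 256, presence_penalty
# is used as a fallback for repetition_penalty, and sampling is enabled only for a
# temperature above 1e-4.
def map_sampling_options(max_tokens=None, min_tokens=None, temperature=None,
                         top_p=None, top_k=None,
                         repetition_penalty=None, presence_penalty=None):
    n_predict = 256 if max_tokens is None else max_tokens
    penalty = repetition_penalty if repetition_penalty is not None else presence_penalty
    do_sample = temperature is not None and temperature > 1e-4
    return {"max_new_tokens": n_predict, "min_new_tokens": min_tokens, "do_sample": do_sample,
            "top_p": top_p, "top_k": top_k,
            "repetition_penalty": penalty, "temperature": temperature}

print(map_sampling_options(temperature=0.2, top_p=0.9))  # sampling on
print(map_sampling_options())                            # greedy, 256 new tokens
```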
@@ -284,13 +340,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 @app.post("/v1/completions")
 async def create_completion(request: CompletionRequest):
     model_name = local_model.model_name
-    if request.max_tokens is None:
-        n_predict = 32
-    else:
-        n_predict = request.max_tokens
     inputs_request = InputsRequest(
         inputs=request.prompt,
-        parameters=Parameters(max_new_tokens=n_predict),
+        parameters=set_parameters(request),
         stream=request.stream,
         req_type="completion"
     )
@@ -73,6 +73,11 @@ class ModelWorker:

             def model_generate():
                 generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None}
+                if "codegeex" in self.model_name.lower():
+                    eos_token_id = [tokenizer.eos_token_id,
+                                    tokenizer.convert_tokens_to_ids("<|user|>"),
+                                    tokenizer.convert_tokens_to_ids("<|observation|>")]
+                    generate_kwargs["eos_token_id"] = eos_token_id
                 self.model.generate(input_ids,
                                     streamer=self.streamer[request_id], **generate_kwargs)
                 torch.xpu.empty_cache()
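
The worker change above registers codegeex's chat-control tokens as extra end-of-sequence ids before calling generate(). The sketch below replays that kwargs assembly with placeholder token ids, since the real ids come from the model's tokenizer.

```python
# Placeholder ids only; in the worker they come from the tokenizer (eos_token_id and
# convert_tokens_to_ids). None-valued generation options are dropped first.
parameters = {"max_new_tokens": 128, "do_sample": False, "top_p": None, "temperature": None}
generate_kwargs = {k: v for k, v in parameters.items() if v is not None}

model_name = "codegeex4-all-9b"                                          # assumed model name
token_to_id = {"<|endoftext|>": 0, "<|user|>": 1, "<|observation|>": 2}  # fake ids
if "codegeex" in model_name.lower():
    generate_kwargs["eos_token_id"] = [token_to_id["<|endoftext|>"],
                                       token_to_id["<|user|>"],
                                       token_to_id["<|observation|>"]]
print(generate_kwargs)  # {'max_new_tokens': 128, 'do_sample': False, 'eos_token_id': [0, 1, 2]}
```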