Merge tgi_api_server to main (#11036)
* init
* fix style
* speculative cannot use benchmark
* add tgi server readme
This commit is contained in:
parent
f60565adc7
commit
a2e1578fd9
5 changed files with 1248 additions and 1 deletion
@@ -124,6 +124,139 @@ This is the user interface that users will interact with.

By following these steps, you will be able to serve your models using the web UI with IPEX-LLM as the backend. You can open your browser and chat with a model now.

### Launch TGI Style API server

When you have started the controller and the worker, you can start the TGI Style API server as follows:

```bash
python3 -m ipex_llm.serving.fastchat.tgi_api_server --host localhost --port 8000
```

You can use `curl` to observe the output of the API.
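Before calling the generation endpoints, it can help to confirm the server is reachable. The sketch below is illustrative only (it is not part of this commit) and assumes the server was launched with the command above, that no `--api-keys` were configured, and that the `requests` package is available; it queries the `/v1/models` route that `tgi_api_server` registers.

```python
# Minimal reachability check for the TGI-style API server started above.
# Assumptions: server on localhost:8000, no --api-keys configured.
import requests

resp = requests.get("http://localhost:8000/v1/models", timeout=10)
resp.raise_for_status()
for card in resp.json().get("data", []):
    print(card["id"])  # names of the models served by the FastChat workers
```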
#### Using /generate API

This endpoint takes a prompt as `inputs` in the request body and returns a single response containing the model-generated answer.

```bash
curl -X POST -H "Content-Type: application/json" -d '{
  "inputs": "What is AI?",
  "parameters": {
    "best_of": 1,
    "decoder_input_details": true,
    "details": true,
    "do_sample": true,
    "frequency_penalty": 0.1,
    "grammar": {
      "type": "json",
      "value": "string"
    },
    "max_new_tokens": 32,
    "repetition_penalty": 1.03,
    "return_full_text": false,
    "seed": 0.1,
    "stop": [
      "photographer"
    ],
    "temperature": 0.5,
    "top_k": 10,
    "top_n_tokens": 5,
    "top_p": 0.95,
    "truncate": true,
    "typical_p": 0.95,
    "watermark": true
  }
}' http://localhost:8000/generate
```

Sample output:

```json
{
  "details": {
    "best_of_sequences": [
      {
        "index": 0,
        "message": {
          "role": "assistant",
          "content": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer "
        },
        "finish_reason": "length",
        "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ",
        "generated_tokens": 31
      }
    ]
  },
  "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ",
  "usage": {
    "prompt_tokens": 4,
    "total_tokens": 35,
    "completion_tokens": 31
  }
}
```
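If you prefer calling the endpoint from Python instead of `curl`, the following sketch posts a reduced version of the same payload with `requests`. It is illustrative only (not part of this commit) and assumes the server from the steps above is listening on `localhost:8000`; omitted parameters fall back to the defaults defined in `tgi_api_protocol.py`.

```python
# Hypothetical Python equivalent of the curl request above.
import requests

payload = {
    "inputs": "What is AI?",
    "parameters": {
        "do_sample": True,
        "max_new_tokens": 32,
        "temperature": 0.5,
        "top_p": 0.95,
    },
}

resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
resp.raise_for_status()
result = resp.json()
print(result["generated_text"])
print(result["usage"])  # prompt_tokens / total_tokens / completion_tokens
```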
#### Using /generate_stream API

This endpoint also takes a prompt as `inputs`, but it keeps a long-lived connection open and continuously streams back multiple responses containing the model-generated answer.

```bash
curl -X POST -H "Content-Type: application/json" -d '{
  "inputs": "What is AI?",
  "parameters": {
    "best_of": 1,
    "decoder_input_details": true,
    "details": true,
    "do_sample": true,
    "frequency_penalty": 0.1,
    "grammar": {
      "type": "json",
      "value": "string"
    },
    "max_new_tokens": 32,
    "repetition_penalty": 1.03,
    "return_full_text": false,
    "seed": 0.1,
    "stop": [
      "photographer"
    ],
    "temperature": 0.5,
    "top_k": 10,
    "top_n_tokens": 5,
    "top_p": 0.95,
    "truncate": true,
    "typical_p": 0.95,
    "watermark": true
  }
}' http://localhost:8000/generate_stream
```

Sample output:

```
data: {"token": {"id": 663359, "text": "", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 300560, "text": "\n", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 725120, "text": "Artificial Intelligence ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 734609, "text": "(AI) is ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 362235, "text": "a branch of computer ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 380983, "text": "science that attempts to ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 249979, "text": "simulate the way that ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 972663, "text": "the human brain ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 793301, "text": "works. It is a ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 501380, "text": "branch of computer ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 673232, "text": "", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null}

data: {"token": {"id": 2, "text": "</s>", "logprob": 0.0, "special": true}, "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", "details": {"finish_reason": "eos_token", "generated_tokens": 31, "prefill_tokens": 4, "seed": 2023}, "special_ret": {"tensor": []}}
```
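A Python client can consume this server-sent event stream by reading the response line by line and stopping at the `[DONE]` sentinel. The sketch below is not part of this commit; it assumes the same host, port, and payload as above, and it accepts either chunk shape seen in this diff (the TGI-style `token` field from the sample output, or a `generated_text` delta).

```python
# Hypothetical streaming client for /generate_stream (server-sent events).
import json
import requests

payload = {"inputs": "What is AI?", "parameters": {"max_new_tokens": 32}}

with requests.post(
    "http://localhost:8000/generate_stream", json=payload, stream=True, timeout=300
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank keep-alive lines between events
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        event = json.loads(data)
        # Prefer the incremental generated_text delta; fall back to token.text.
        piece = event.get("generated_text")
        if piece is None and isinstance(event.get("token"), dict):
            piece = event["token"].get("text")
        print(piece or "", end="", flush=True)
print()
```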
### Launch RESTful API server

To start an OpenAI API server that provides compatible APIs using the IPEX-LLM backend, you can launch the `openai_api_server` and follow this [doc](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) to use it.
@@ -104,7 +104,7 @@ class BigDLLLMWorker(BaseModelWorker):
             speculative,
             load_low_bit_model,
         )
-        if benchmark.lower() == "true":
+        if benchmark.lower() == "true" and not speculative:
             from ipex_llm.utils.benchmark_util import BenchmarkWrapper
             self.model = BenchmarkWrapper(self.model, do_print=True)
             logger.info(f"enable benchmark successfully")
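The new condition means the model is only wrapped for benchmarking when speculative decoding is disabled. Below is a minimal, hypothetical sketch of that guard with a stand-in timing wrapper; the real `BenchmarkWrapper` lives in `ipex_llm.utils.benchmark_util` and is not shown in this diff.

```python
# Illustrative only: _TimingWrapper is a hypothetical stand-in for
# ipex_llm.utils.benchmark_util.BenchmarkWrapper.
import time


class _TimingWrapper:
    def __init__(self, model, do_print=True):
        self.model = model
        self.do_print = do_print

    def generate(self, *args, **kwargs):
        start = time.perf_counter()
        output = self.model.generate(*args, **kwargs)
        if self.do_print:
            print(f"generate() took {time.perf_counter() - start:.3f} s")
        return output


def maybe_wrap_for_benchmark(model, benchmark: str, speculative: bool):
    # Mirrors the diff above: benchmarking is skipped in speculative mode.
    if benchmark.lower() == "true" and not speculative:
        return _TimingWrapper(model, do_print=True)
    return model
```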
222  python/llm/src/ipex_llm/serving/fastchat/tgi_api_protocol.py  Normal file
|
|
@@ -0,0 +1,222 @@
|
||||||
|
from typing import Literal, Optional, List, Dict, Any, Union
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import shortuuid
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
import sys
|
||||||
|
|
||||||
|
pseudo_infinite_int = sys.maxsize
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
object: str = "error"
|
||||||
|
message: str
|
||||||
|
code: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPermission(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}")
|
||||||
|
object: str = "model_permission"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
allow_create_engine: bool = False
|
||||||
|
allow_sampling: bool = True
|
||||||
|
allow_logprobs: bool = True
|
||||||
|
allow_search_indices: bool = True
|
||||||
|
allow_view: bool = True
|
||||||
|
allow_fine_tuning: bool = False
|
||||||
|
organization: str = "*"
|
||||||
|
group: Optional[str] = None
|
||||||
|
is_blocking: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCard(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "model"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
owned_by: str = "fastchat"
|
||||||
|
root: Optional[str] = None
|
||||||
|
parent: Optional[str] = None
|
||||||
|
permission: List[ModelPermission] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ModelList(BaseModel):
|
||||||
|
object: str = "list"
|
||||||
|
data: List[ModelCard] = []
|
||||||
|
|
||||||
|
|
||||||
|
class UsageInfo(BaseModel):
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
completion_tokens: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class LogProbs(BaseModel):
|
||||||
|
text_offset: List[int] = Field(default_factory=list)
|
||||||
|
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||||
|
tokens: List[str] = Field(default_factory=list)
|
||||||
|
top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class Grammar(BaseModel):
|
||||||
|
type: str = "json"
|
||||||
|
value: str = "string"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionParam(BaseModel):
|
||||||
|
best_of: Optional[int] = 1
|
||||||
|
decoder_input_details: Optional[bool] = True
|
||||||
|
details: Optional[bool] = False
|
||||||
|
do_sample: Optional[bool] = False
|
||||||
|
frequency_penalty: Optional[float] = 0.1
|
||||||
|
grammar: Optional[Grammar] = Grammar()
|
||||||
|
max_new_tokens: Optional[int] = None
|
||||||
|
repetition_penalty: Optional[float] = 1.0
|
||||||
|
return_full_text: Optional[bool] = False
|
||||||
|
seed: Optional[float] = None
|
||||||
|
stop: Optional[List[str]] = []
|
||||||
|
temperature: Optional[float] = 1.0
|
||||||
|
top_k: Optional[int] = pseudo_infinite_int
|
||||||
|
top_n_tokens: Optional[int] = 5
|
||||||
|
top_p: Optional[float] = 1.0
|
||||||
|
truncate: Optional[bool] = False
|
||||||
|
typical_p: Optional[float] = 0.95
|
||||||
|
watermark: Optional[bool] = True
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
inputs: str
|
||||||
|
parameters: ChatCompletionParam
|
||||||
|
model: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
message: ChatMessage
|
||||||
|
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||||
|
generated_text: str
|
||||||
|
generated_tokens: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionDetails(BaseModel):
|
||||||
|
best_of_sequences: List[ChatCompletionChoice]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionResponse(BaseModel):
|
||||||
|
details: ChatCompletionDetails
|
||||||
|
generated_text: str
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class DeltaMessage(BaseModel):
|
||||||
|
role: Optional[str] = None
|
||||||
|
content: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionStreamChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
message: DeltaMessage
|
||||||
|
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionStreamDetails(BaseModel):
|
||||||
|
best_of_sequences: List[ChatCompletionStreamChoice]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionStreamResponse(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}")
|
||||||
|
object: str = "chat.completion.chunk"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
details: ChatCompletionStreamDetails
|
||||||
|
generated_text: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCheckRequestItem(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: str
|
||||||
|
max_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCheckRequest(BaseModel):
|
||||||
|
prompts: List[TokenCheckRequestItem]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCheckResponseItem(BaseModel):
|
||||||
|
fits: bool
|
||||||
|
tokenCount: int
|
||||||
|
contextLength: int
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCheckResponse(BaseModel):
|
||||||
|
prompts: List[TokenCheckResponseItem]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsRequest(BaseModel):
|
||||||
|
model: Optional[str] = None
|
||||||
|
engine: Optional[str] = None
|
||||||
|
input: Union[str, List[Any]]
|
||||||
|
user: Optional[str] = None
|
||||||
|
encoding_format: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingsResponse(BaseModel):
|
||||||
|
object: str = "list"
|
||||||
|
data: List[Dict[str, Any]]
|
||||||
|
model: str
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
prompt: Union[str, List[Any]]
|
||||||
|
suffix: Optional[str] = None
|
||||||
|
temperature: Optional[float] = 0.7
|
||||||
|
n: Optional[int] = 1
|
||||||
|
max_tokens: Optional[int] = 16
|
||||||
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
top_p: Optional[float] = 1.0
|
||||||
|
top_k: Optional[int] = -1
|
||||||
|
logprobs: Optional[int] = None
|
||||||
|
echo: Optional[bool] = False
|
||||||
|
presence_penalty: Optional[float] = 0.0
|
||||||
|
frequency_penalty: Optional[float] = 0.0
|
||||||
|
user: Optional[str] = None
|
||||||
|
use_beam_search: Optional[bool] = False
|
||||||
|
best_of: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[LogProbs] = None
|
||||||
|
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponse(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
|
||||||
|
object: str = "text_completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[CompletionResponseChoice]
|
||||||
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionResponseStreamChoice(BaseModel):
|
||||||
|
index: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[LogProbs] = None
|
||||||
|
finish_reason: Optional[Literal["stop", "length"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CompletionStreamResponse(BaseModel):
|
||||||
|
id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}")
|
||||||
|
object: str = "text_completion"
|
||||||
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
|
model: str
|
||||||
|
choices: List[CompletionResponseStreamChoice]
|
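The request and response models above map directly onto the JSON shown in the README hunk. As a quick illustration (not part of this commit), the sketch below validates a minimal `/generate` payload against these models; it assumes the `ipex_llm` package is installed so that the module is importable at the path shown in the file header.

```python
# Illustrative only: parse a minimal /generate payload with the models above.
from ipex_llm.serving.fastchat.tgi_api_protocol import (
    ChatCompletionParam,
    ChatCompletionRequest,
)

payload = {
    "inputs": "What is AI?",
    "parameters": {"max_new_tokens": 32, "temperature": 0.5, "top_p": 0.95},
}

req = ChatCompletionRequest(**payload)      # pydantic coerces the nested dict
assert isinstance(req.parameters, ChatCompletionParam)
print(req.parameters.max_new_tokens)        # 32
print(req.parameters.repetition_penalty)    # 1.0 (field default)
```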
758  python/llm/src/ipex_llm/serving/fastchat/tgi_api_server  Normal file
|
|
@@ -0,0 +1,758 @@
|
||||||
|
#
|
||||||
|
# Copyright 2016 The BigDL Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""
|
||||||
|
Modified from
|
||||||
|
https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
|
||||||
|
|
||||||
|
A server that provides OpenAI-compatible RESTful APIs. It supports:
|
||||||
|
|
||||||
|
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
|
||||||
|
- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions)
|
||||||
|
- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m ipex_llm.serving.fastchat.tgi_api_server
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Generator, Optional, Union, Dict, List, Any
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import fastapi
|
||||||
|
from fastapi import Depends, HTTPException
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pydantic.v1 import BaseSettings
|
||||||
|
except ImportError:
|
||||||
|
from pydantic import BaseSettings
|
||||||
|
import shortuuid
|
||||||
|
import tiktoken
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from fastchat.constants import (
|
||||||
|
WORKER_API_TIMEOUT,
|
||||||
|
WORKER_API_EMBEDDING_BATCH_SIZE,
|
||||||
|
ErrorCode,
|
||||||
|
)
|
||||||
|
from fastchat.conversation import Conversation, SeparatorStyle
|
||||||
|
from .tgi_api_protocol import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
ChatCompletionStreamChoice,
|
||||||
|
ChatCompletionStreamDetails,
|
||||||
|
ChatCompletionStreamResponse,
|
||||||
|
ChatMessage,
|
||||||
|
ChatCompletionChoice,
|
||||||
|
CompletionRequest,
|
||||||
|
CompletionResponse,
|
||||||
|
CompletionResponseChoice,
|
||||||
|
DeltaMessage,
|
||||||
|
CompletionResponseStreamChoice,
|
||||||
|
CompletionStreamResponse,
|
||||||
|
EmbeddingsRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
|
ErrorResponse,
|
||||||
|
LogProbs,
|
||||||
|
ModelCard,
|
||||||
|
ModelList,
|
||||||
|
ModelPermission,
|
||||||
|
UsageInfo,
|
||||||
|
ChatCompletionParam,
|
||||||
|
ChatCompletionDetails,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
)
|
||||||
|
from fastchat.protocol.api_protocol import (
|
||||||
|
APIChatCompletionRequest,
|
||||||
|
APITokenCheckRequest,
|
||||||
|
APITokenCheckResponse,
|
||||||
|
APITokenCheckResponseItem,
|
||||||
|
)
|
||||||
|
from fastchat.utils import build_logger
|
||||||
|
|
||||||
|
logger = build_logger("tgi_api_server", "tgi_api_server.log")
|
||||||
|
|
||||||
|
conv_template_map = {}
|
||||||
|
|
||||||
|
fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
||||||
|
|
||||||
|
async def fetch_remote(url, pload=None, name=None):
|
||||||
|
async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
|
||||||
|
async with session.post(url, json=pload) as response:
|
||||||
|
chunks = []
|
||||||
|
if response.status != 200:
|
||||||
|
ret = {
|
||||||
|
"text": f"{response.reason}",
|
||||||
|
"error_code": ErrorCode.INTERNAL_ERROR,
|
||||||
|
}
|
||||||
|
return json.dumps(ret)
|
||||||
|
|
||||||
|
async for chunk, _ in response.content.iter_chunks():
|
||||||
|
chunks.append(chunk)
|
||||||
|
output = b"".join(chunks)
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
res = json.loads(output)
|
||||||
|
if name != "":
|
||||||
|
res = res[name]
|
||||||
|
return res
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class AppSettings(BaseSettings):
|
||||||
|
# The address of the model controller.
|
||||||
|
controller_address: str = "http://localhost:21001"
|
||||||
|
api_keys: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
app_settings = AppSettings()
|
||||||
|
app = fastapi.FastAPI()
|
||||||
|
headers = {"User-Agent": "FastChat API Server"}
|
||||||
|
get_bearer_token = HTTPBearer(auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_api_key(
|
||||||
|
auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
|
||||||
|
) -> str:
|
||||||
|
if app_settings.api_keys:
|
||||||
|
if auth is None or (token := auth.credentials) not in app_settings.api_keys:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail={
|
||||||
|
"error": {
|
||||||
|
"message": "",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": None,
|
||||||
|
"code": "invalid_api_key",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return token
|
||||||
|
else:
|
||||||
|
# api_keys not set; allow all
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def create_error_response(code: int, message: str) -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
ErrorResponse(message=message, code=code).dict(), status_code=400
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.exception_handler(RequestValidationError)
|
||||||
|
async def validation_exception_handler(request, exc):
|
||||||
|
return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
async def check_model(request) -> Optional[JSONResponse]:
|
||||||
|
controller_address = app_settings.controller_address
|
||||||
|
ret = None
|
||||||
|
|
||||||
|
models = await fetch_remote(controller_address + "/list_models", None, "models")
|
||||||
|
if request.model not in models:
|
||||||
|
ret = create_error_response(
|
||||||
|
ErrorCode.INVALID_MODEL,
|
||||||
|
f"Only {'&&'.join(models)} allowed now, your model {request.model}",
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
async def check_length(request, prompt, max_tokens, worker_addr):
|
||||||
|
if (
|
||||||
|
not isinstance(max_tokens, int) or max_tokens <= 0
|
||||||
|
): # model worker not support max_tokens=None
|
||||||
|
max_tokens = 1024 * 1024
|
||||||
|
|
||||||
|
context_len = await fetch_remote(
|
||||||
|
worker_addr + "/model_details", {"model": request.model}, "context_length"
|
||||||
|
)
|
||||||
|
token_num = await fetch_remote(
|
||||||
|
worker_addr + "/count_token",
|
||||||
|
{"model": request.model, "prompt": prompt},
|
||||||
|
"count",
|
||||||
|
)
|
||||||
|
length = min(max_tokens, context_len - token_num)
|
||||||
|
|
||||||
|
if length <= 0:
|
||||||
|
return None, create_error_response(
|
||||||
|
ErrorCode.CONTEXT_OVERFLOW,
|
||||||
|
f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. Please reduce the length of the messages.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return length, None
|
||||||
|
|
||||||
|
|
||||||
|
def check_requests(request) -> Optional[JSONResponse]:
|
||||||
|
# Check all params
|
||||||
|
if request.parameters.max_new_tokens is not None and request.parameters.max_new_tokens <= 0:
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.max_new_tokens} is less than the minimum of 1 - 'max_tokens'",
|
||||||
|
)
|
||||||
|
if request.parameters.best_of is not None and request.parameters.best_of <= 0:
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.best_of} is less than the minimum of 1 - 'n'",
|
||||||
|
)
|
||||||
|
if request.parameters.temperature is not None and request.parameters.temperature < 0:
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.temperature} is less than the minimum of 0 - 'temperature'",
|
||||||
|
)
|
||||||
|
if request.parameters.temperature is not None and request.parameters.temperature > 2:
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.temperature} is greater than the maximum of 2 - 'temperature'",
|
||||||
|
)
|
||||||
|
if request.parameters.top_p is not None and request.parameters.top_p < 0:
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.top_p} is less than the minimum of 0 - 'top_p'",
|
||||||
|
)
|
||||||
|
if request.parameters.top_p is not None and request.parameters.top_p > 1:
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.top_p} is greater than the maximum of 1 - 'top_p'",
|
||||||
|
)
|
||||||
|
if request.parameters.top_k is not None and (request.parameters.top_k > -1 and request.parameters.top_k < 1):
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.top_k} is out of Range. Either set top_k to -1 or >=1.",
|
||||||
|
)
|
||||||
|
if request.parameters.stop is not None and (
|
||||||
|
not isinstance(request.parameters.stop, str) and not isinstance(request.parameters.stop, list)
|
||||||
|
):
|
||||||
|
return create_error_response(
|
||||||
|
ErrorCode.PARAM_OUT_OF_RANGE,
|
||||||
|
f"{request.parameters.stop} is not valid under any of the given schemas - 'stop'",
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def process_input(model_name, inp):
|
||||||
|
if isinstance(inp, str):
|
||||||
|
inp = [inp]
|
||||||
|
elif isinstance(inp, list):
|
||||||
|
if isinstance(inp[0], int):
|
||||||
|
try:
|
||||||
|
decoding = tiktoken.model.encoding_for_model(model_name)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||||
|
model = "cl100k_base"
|
||||||
|
decoding = tiktoken.get_encoding(model)
|
||||||
|
inp = [decoding.decode(inp)]
|
||||||
|
elif isinstance(inp[0], list):
|
||||||
|
try:
|
||||||
|
decoding = tiktoken.model.encoding_for_model(model_name)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("Warning: model not found. Using cl100k_base encoding.")
|
||||||
|
model = "cl100k_base"
|
||||||
|
decoding = tiktoken.get_encoding(model)
|
||||||
|
inp = [decoding.decode(text) for text in inp]
|
||||||
|
|
||||||
|
return inp
|
||||||
|
|
||||||
|
|
||||||
|
def create_openai_logprobs(logprob_dict):
|
||||||
|
"""Create OpenAI-style logprobs."""
|
||||||
|
return LogProbs(**logprob_dict) if logprob_dict is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
def _add_to_set(s, new_stop):
|
||||||
|
if not s:
|
||||||
|
return
|
||||||
|
if isinstance(s, str):
|
||||||
|
new_stop.add(s)
|
||||||
|
else:
|
||||||
|
new_stop.update(s)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_gen_params(
|
||||||
|
model_name: str,
|
||||||
|
worker_addr: str,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
*,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: Optional[int],
|
||||||
|
presence_penalty: Optional[float],
|
||||||
|
frequency_penalty: Optional[float],
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
echo: Optional[bool],
|
||||||
|
logprobs: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, List[str]]],
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
use_beam_search: Optional[bool] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
conv = await get_conv(model_name, worker_addr)
|
||||||
|
conv = Conversation(
|
||||||
|
name=conv["name"],
|
||||||
|
system_template=conv["system_template"],
|
||||||
|
system_message=conv["system_message"],
|
||||||
|
roles=conv["roles"],
|
||||||
|
messages=list(conv["messages"]), # prevent in-place modification
|
||||||
|
offset=conv["offset"],
|
||||||
|
sep_style=SeparatorStyle(conv["sep_style"]),
|
||||||
|
sep=conv["sep"],
|
||||||
|
sep2=conv["sep2"],
|
||||||
|
stop_str=conv["stop_str"],
|
||||||
|
stop_token_ids=conv["stop_token_ids"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(messages, str):
|
||||||
|
prompt = messages
|
||||||
|
images = []
|
||||||
|
else:
|
||||||
|
for message in messages:
|
||||||
|
msg_role = message["role"]
|
||||||
|
if msg_role == "system":
|
||||||
|
conv.set_system_message(message["content"])
|
||||||
|
elif msg_role == "user":
|
||||||
|
if type(message["content"]) == list:
|
||||||
|
image_list = [
|
||||||
|
item["image_url"]["url"]
|
||||||
|
for item in message["content"]
|
||||||
|
if item["type"] == "image_url"
|
||||||
|
]
|
||||||
|
text_list = [
|
||||||
|
item["text"]
|
||||||
|
for item in message["content"]
|
||||||
|
if item["type"] == "text"
|
||||||
|
]
|
||||||
|
|
||||||
|
text = "\n".join(text_list)
|
||||||
|
conv.append_message(conv.roles[0], (text, image_list))
|
||||||
|
else:
|
||||||
|
conv.append_message(conv.roles[0], message["content"])
|
||||||
|
elif msg_role == "assistant":
|
||||||
|
conv.append_message(conv.roles[1], message["content"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown role: {msg_role}")
|
||||||
|
|
||||||
|
# Add a blank message for the assistant.
|
||||||
|
conv.append_message(conv.roles[1], None)
|
||||||
|
prompt = conv.get_prompt()
|
||||||
|
images = conv.get_images()
|
||||||
|
|
||||||
|
gen_params = {
|
||||||
|
"model": model_name,
|
||||||
|
"prompt": prompt,
|
||||||
|
"temperature": temperature,
|
||||||
|
"logprobs": logprobs,
|
||||||
|
"top_p": top_p,
|
||||||
|
"top_k": top_k,
|
||||||
|
"presence_penalty": presence_penalty,
|
||||||
|
"frequency_penalty": frequency_penalty,
|
||||||
|
"max_new_tokens": max_tokens,
|
||||||
|
"echo": echo,
|
||||||
|
"stop_token_ids": conv.stop_token_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(images) > 0:
|
||||||
|
gen_params["images"] = images
|
||||||
|
|
||||||
|
if best_of is not None:
|
||||||
|
gen_params.update({"best_of": best_of})
|
||||||
|
if use_beam_search is not None:
|
||||||
|
gen_params.update({"use_beam_search": use_beam_search})
|
||||||
|
|
||||||
|
new_stop = set()
|
||||||
|
_add_to_set(stop, new_stop)
|
||||||
|
_add_to_set(conv.stop_str, new_stop)
|
||||||
|
|
||||||
|
gen_params["stop"] = list(new_stop)
|
||||||
|
|
||||||
|
logger.debug(f"==== request ====\n{gen_params}")
|
||||||
|
return gen_params
|
||||||
|
|
||||||
|
|
||||||
|
async def get_worker_address(model_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Get worker address based on the requested model
|
||||||
|
|
||||||
|
:param model_name: The worker's model name
|
||||||
|
:return: Worker address from the controller
|
||||||
|
:raises: :class:`ValueError`: No available worker for requested model
|
||||||
|
"""
|
||||||
|
controller_address = app_settings.controller_address
|
||||||
|
worker_addr = await fetch_remote(
|
||||||
|
controller_address + "/get_worker_address", {"model": model_name}, "address"
|
||||||
|
)
|
||||||
|
|
||||||
|
# No available worker
|
||||||
|
if worker_addr == "":
|
||||||
|
raise ValueError(f"No available worker for {model_name}")
|
||||||
|
logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
||||||
|
return worker_addr
|
||||||
|
|
||||||
|
|
||||||
|
async def get_conv(model_name: str, worker_addr: str):
|
||||||
|
conv_template = conv_template_map.get((worker_addr, model_name))
|
||||||
|
if conv_template is None:
|
||||||
|
conv_template = await fetch_remote(
|
||||||
|
worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
|
||||||
|
)
|
||||||
|
conv_template_map[(worker_addr, model_name)] = conv_template
|
||||||
|
return conv_template
|
||||||
|
|
||||||
|
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||||
|
async def show_available_models():
|
||||||
|
controller_address = app_settings.controller_address
|
||||||
|
ret = await fetch_remote(controller_address + "/refresh_all_workers")
|
||||||
|
models = await fetch_remote(controller_address + "/list_models", None, "models")
|
||||||
|
|
||||||
|
models.sort()
|
||||||
|
# TODO: return real model permission details
|
||||||
|
model_cards = []
|
||||||
|
for m in models:
|
||||||
|
model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()]))
|
||||||
|
return ModelList(data=model_cards)
|
||||||
|
|
||||||
|
async def get_last_model_name_from_list():
|
||||||
|
models = await show_available_models()
|
||||||
|
return models.data[-1].id
|
||||||
|
|
||||||
|
@app.post("/generate", dependencies=[Depends(check_api_key)])
|
||||||
|
async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
|
"""Creates a completion for the chat message"""
|
||||||
|
|
||||||
|
request.model = await get_last_model_name_from_list()
|
||||||
|
|
||||||
|
worker_addr = await get_worker_address(request.model)
|
||||||
|
|
||||||
|
gen_params = await get_gen_params(
|
||||||
|
request.model,
|
||||||
|
worker_addr,
|
||||||
|
request.inputs,
|
||||||
|
temperature=request.parameters.temperature,
|
||||||
|
top_p=request.parameters.top_p,
|
||||||
|
top_k=request.parameters.top_k,
|
||||||
|
presence_penalty=request.parameters.repetition_penalty,
|
||||||
|
frequency_penalty=request.parameters.frequency_penalty,
|
||||||
|
max_tokens=request.parameters.max_new_tokens,
|
||||||
|
echo=False,
|
||||||
|
stop=request.parameters.stop,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_new_tokens, error_check_ret = await check_length(
|
||||||
|
request,
|
||||||
|
gen_params["prompt"],
|
||||||
|
gen_params["max_new_tokens"],
|
||||||
|
worker_addr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
gen_params["max_new_tokens"] = max_new_tokens
|
||||||
|
|
||||||
|
choices = []
|
||||||
|
chat_completions = []
|
||||||
|
for i in range(request.parameters.best_of):
|
||||||
|
content = asyncio.create_task(generate_completion(gen_params, worker_addr))
|
||||||
|
chat_completions.append(content)
|
||||||
|
try:
|
||||||
|
all_tasks = await asyncio.gather(*chat_completions)
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
|
||||||
|
usage = UsageInfo()
|
||||||
|
for i, content in enumerate(all_tasks):
|
||||||
|
if isinstance(content, str):
|
||||||
|
content = json.loads(content)
|
||||||
|
|
||||||
|
if content["error_code"] != 0:
|
||||||
|
return create_error_response(content["error_code"], content["text"])
|
||||||
|
choices.append(
|
||||||
|
ChatCompletionChoice(
|
||||||
|
index=i,
|
||||||
|
message=ChatMessage(role="assistant", content=content["text"]),
|
||||||
|
finish_reason=content.get("finish_reason", "stop"),
|
||||||
|
generated_text=content["text"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "usage" in content:
|
||||||
|
task_usage = UsageInfo.parse_obj(content["usage"])
|
||||||
|
choices[-1].generated_tokens = task_usage.dict()["completion_tokens"]
|
||||||
|
for usage_key, usage_value in task_usage.dict().items():
|
||||||
|
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||||||
|
|
||||||
|
details = ChatCompletionDetails(
|
||||||
|
best_of_sequences=choices
|
||||||
|
)
|
||||||
|
generated_text = choices[0].message.content
|
||||||
|
|
||||||
|
return ChatCompletionResponse(details=details, usage=usage, generated_text=generated_text)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/generate_stream", dependencies=[Depends(check_api_key)])
|
||||||
|
async def create_chat_completion_stream(request: ChatCompletionRequest):
|
||||||
|
"""Creates a completion for the chat message"""
|
||||||
|
|
||||||
|
request.model = await get_last_model_name_from_list()
|
||||||
|
|
||||||
|
worker_addr = await get_worker_address(request.model)
|
||||||
|
|
||||||
|
gen_params = await get_gen_params(
|
||||||
|
request.model,
|
||||||
|
worker_addr,
|
||||||
|
request.inputs,
|
||||||
|
temperature=request.parameters.temperature,
|
||||||
|
top_p=request.parameters.top_p,
|
||||||
|
top_k=request.parameters.top_k,
|
||||||
|
presence_penalty=request.parameters.repetition_penalty,
|
||||||
|
frequency_penalty=request.parameters.frequency_penalty,
|
||||||
|
max_tokens=request.parameters.max_new_tokens,
|
||||||
|
echo=False,
|
||||||
|
stop=request.parameters.stop,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_new_tokens, error_check_ret = await check_length(
|
||||||
|
request,
|
||||||
|
gen_params["prompt"],
|
||||||
|
gen_params["max_new_tokens"],
|
||||||
|
worker_addr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if error_check_ret is not None:
|
||||||
|
return error_check_ret
|
||||||
|
|
||||||
|
gen_params["max_new_tokens"] = max_new_tokens
|
||||||
|
|
||||||
|
generator = chat_completion_stream_generator(
|
||||||
|
request.model, gen_params, request.parameters.best_of, worker_addr
|
||||||
|
)
|
||||||
|
return StreamingResponse(generator, media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_stream_generator(
|
||||||
|
model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str
|
||||||
|
) -> Generator[str, Any, None]:
|
||||||
|
"""
|
||||||
|
Event stream format:
|
||||||
|
https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
|
||||||
|
"""
|
||||||
|
id = f"chatcmpl-{shortuuid.random()}"
|
||||||
|
finish_stream_events = []
|
||||||
|
|
||||||
|
for i in range(n):
|
||||||
|
# First chunk with role
|
||||||
|
choice_data = ChatCompletionStreamChoice(
|
||||||
|
index=i,
|
||||||
|
message=DeltaMessage(role="assistant"),
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
details = ChatCompletionStreamDetails(best_of_sequences=[choice_data])
|
||||||
|
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=id, details=details, generated_text=""
|
||||||
|
)
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
previous_text = ""
|
||||||
|
async for content in generate_completion_stream(gen_params, worker_addr):
|
||||||
|
if content["error_code"] != 0:
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
decoded_unicode = content["text"].replace("\ufffd", "")
|
||||||
|
delta_text = decoded_unicode[len(previous_text) :]
|
||||||
|
previous_text = (
|
||||||
|
decoded_unicode
|
||||||
|
if len(decoded_unicode) > len(previous_text)
|
||||||
|
else previous_text
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(delta_text) == 0:
|
||||||
|
delta_text = None
|
||||||
|
choice_data = ChatCompletionStreamChoice(
|
||||||
|
index=i,
|
||||||
|
message=DeltaMessage(content=delta_text),
|
||||||
|
finish_reason=content.get("finish_reason", None),
|
||||||
|
)
|
||||||
|
details = ChatCompletionStreamDetails(best_of_sequences=[choice_data])
|
||||||
|
print(f"type of delta_text: {type(delta_text)}")
|
||||||
|
chunk = ChatCompletionStreamResponse(
|
||||||
|
id=id,
|
||||||
|
details=details,
|
||||||
|
generated_text=delta_text
|
||||||
|
)
|
||||||
|
if delta_text is None:
|
||||||
|
if content.get("finish_reason", None) is not None:
|
||||||
|
finish_stream_events.append(chunk)
|
||||||
|
continue
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
# There is no "content" field in the last delta message, so unset fields are excluded to drop "content".
|
||||||
|
for finish_chunk in finish_stream_events:
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
async def generate_completion_stream_generator(
|
||||||
|
request: CompletionRequest, n: int, worker_addr: str
|
||||||
|
):
|
||||||
|
model_name = request.model
|
||||||
|
id = f"cmpl-{shortuuid.random()}"
|
||||||
|
finish_stream_events = []
|
||||||
|
for text in request.prompt:
|
||||||
|
for i in range(n):
|
||||||
|
previous_text = ""
|
||||||
|
gen_params = await get_gen_params(
|
||||||
|
request.model,
|
||||||
|
worker_addr,
|
||||||
|
text,
|
||||||
|
temperature=request.temperature,
|
||||||
|
top_p=request.top_p,
|
||||||
|
top_k=request.top_k,
|
||||||
|
presence_penalty=request.presence_penalty,
|
||||||
|
frequency_penalty=request.frequency_penalty,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
logprobs=request.logprobs,
|
||||||
|
echo=request.echo,
|
||||||
|
stop=request.stop,
|
||||||
|
)
|
||||||
|
async for content in generate_completion_stream(gen_params, worker_addr):
|
||||||
|
if content["error_code"] != 0:
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
return
|
||||||
|
decoded_unicode = content["text"].replace("\ufffd", "")
|
||||||
|
delta_text = decoded_unicode[len(previous_text) :]
|
||||||
|
previous_text = (
|
||||||
|
decoded_unicode
|
||||||
|
if len(decoded_unicode) > len(previous_text)
|
||||||
|
else previous_text
|
||||||
|
)
|
||||||
|
# todo: index is not apparent
|
||||||
|
choice_data = CompletionResponseStreamChoice(
|
||||||
|
index=i,
|
||||||
|
text=delta_text,
|
||||||
|
logprobs=create_openai_logprobs(content.get("logprobs", None)),
|
||||||
|
finish_reason=content.get("finish_reason", None),
|
||||||
|
)
|
||||||
|
chunk = CompletionStreamResponse(
|
||||||
|
id=id,
|
||||||
|
object="text_completion",
|
||||||
|
choices=[choice_data],
|
||||||
|
model=model_name,
|
||||||
|
)
|
||||||
|
if len(delta_text) == 0:
|
||||||
|
if content.get("finish_reason", None) is not None:
|
||||||
|
finish_stream_events.append(chunk)
|
||||||
|
continue
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
# There is no "content" field in the last delta message, so unset fields are excluded to drop "content".
|
||||||
|
for finish_chunk in finish_stream_events:
|
||||||
|
yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
|
||||||
|
controller_address = app_settings.controller_address
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
delimiter = b"\0"
|
||||||
|
async with client.stream(
|
||||||
|
"POST",
|
||||||
|
worker_addr + "/worker_generate_stream",
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=WORKER_API_TIMEOUT,
|
||||||
|
) as response:
|
||||||
|
# content = await response.aread()
|
||||||
|
buffer = b""
|
||||||
|
async for raw_chunk in response.aiter_raw():
|
||||||
|
buffer += raw_chunk
|
||||||
|
while (chunk_end := buffer.find(delimiter)) >= 0:
|
||||||
|
chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :]
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
|
yield json.loads(chunk.decode())
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
|
||||||
|
return await fetch_remote(worker_addr + "/worker_generate", payload, "")
|
||||||
|
|
||||||
|
### END GENERAL API - NOT OPENAI COMPATIBLE ###
|
||||||
|
|
||||||
|
|
||||||
|
def create_openai_api_server():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="FastChat ChatGPT-Compatible RESTful API server."
|
||||||
|
)
|
||||||
|
parser.add_argument("--host", type=str, default="localhost", help="host name")
|
||||||
|
parser.add_argument("--port", type=int, default=8000, help="port number")
|
||||||
|
parser.add_argument(
|
||||||
|
"--controller-address", type=str, default="http://localhost:21001"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allow-credentials", action="store_true", help="allow credentials"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--api-keys",
|
||||||
|
type=lambda s: s.split(","),
|
||||||
|
help="Optional list of comma separated API keys",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ssl",
|
||||||
|
action="store_true",
|
||||||
|
required=False,
|
||||||
|
default=False,
|
||||||
|
help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=args.allowed_origins,
|
||||||
|
allow_credentials=args.allow_credentials,
|
||||||
|
allow_methods=args.allowed_methods,
|
||||||
|
allow_headers=args.allowed_headers,
|
||||||
|
)
|
||||||
|
app_settings.controller_address = args.controller_address
|
||||||
|
app_settings.api_keys = args.api_keys
|
||||||
|
|
||||||
|
logger.info(f"args: {args}")
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = create_openai_api_server()
|
||||||
|
if args.ssl:
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host=args.host,
|
||||||
|
port=args.port,
|
||||||
|
log_level="info",
|
||||||
|
ssl_keyfile=os.environ["SSL_KEYFILE"],
|
||||||
|
ssl_certfile=os.environ["SSL_CERTFILE"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|