From a2e1578fd9a887a5c8169ee32ddff60e6a83f769 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Mon, 20 May 2024 09:15:03 +0800 Subject: [PATCH] Merge tgi_api_server to main (#11036) * init * fix style * speculative can not use benchmark * add tgi server readme --- .../doc/LLM/Quickstart/fastchat_quickstart.md | 133 +++ .../src/ipex_llm/serving/fastchat/README.md | 134 ++++ .../serving/fastchat/ipex_llm_worker.py | 2 +- .../serving/fastchat/tgi_api_protocol.py | 222 +++++ .../ipex_llm/serving/fastchat/tgi_api_server | 758 ++++++++++++++++++ 5 files changed, 1248 insertions(+), 1 deletion(-) create mode 100644 python/llm/src/ipex_llm/serving/fastchat/tgi_api_protocol.py create mode 100644 python/llm/src/ipex_llm/serving/fastchat/tgi_api_server diff --git a/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md b/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md index 2eb44ace..3beb6075 100644 --- a/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md +++ b/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md @@ -124,6 +124,139 @@ This is the user interface that users will interact with. By following these steps, you will be able to serve your models using the web UI with IPEX-LLM as the backend. You can open your browser and chat with a model now. +### Launch TGI Style API server + +When you have started the controller and the worker, you can start TGI Style API server as follows: + +```bash +python3 -m ipex_llm.serving.fastchat.tgi_api_server --host localhost --port 8000 +``` +You can use `curl` for observing the output of the api + +#### Using /generate API + +This is to send a sentence as inputs in the request, and is expected to receive a response containing model-generated answer. + +```bash +curl -X POST -H "Content-Type: application/json" -d '{ + "inputs": "What is AI?", + "parameters": { + "best_of": 1, + "decoder_input_details": true, + "details": true, + "do_sample": true, + "frequency_penalty": 0.1, + "grammar": { + "type": "json", + "value": "string" + }, + "max_new_tokens": 32, + "repetition_penalty": 1.03, + "return_full_text": false, + "seed": 0.1, + "stop": [ + "photographer" + ], + "temperature": 0.5, + "top_k": 10, + "top_n_tokens": 5, + "top_p": 0.95, + "truncate": true, + "typical_p": 0.95, + "watermark": true + } +}' http://localhost:8000/generate +``` + +Sample output: +```bash +{ + "details": { + "best_of_sequences": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer " + }, + "finish_reason": "length", + "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", + "generated_tokens": 31 + } + ] + }, + "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", + "usage": { + "prompt_tokens": 4, + "total_tokens": 35, + "completion_tokens": 31 + } +} +``` + +#### Using /generate_stream API + +This is to send a sentence as inputs in the request, and a long connection will be opened to continuously receive multiple responses containing model-generated answer. 
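If you prefer Python over `curl`, the event stream can be consumed with a few lines of `requests` code. The sketch below is illustrative only and assumes the server started above is listening on `localhost:8000`; the full `curl` example with every supported parameter follows.

```python
import json

import requests

payload = {
    "inputs": "What is AI?",
    "parameters": {"max_new_tokens": 32, "do_sample": True, "temperature": 0.5},
}

# Each event arrives as a line of the form `data: {...}`; the final chunk
# additionally carries the full `generated_text` and generation details.
with requests.post(
    "http://localhost:8000/generate_stream", json=payload, stream=True, timeout=3600
) as response:
    for line in response.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["token"]["text"], end="", flush=True)
```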
+ +```bash +curl -X POST -H "Content-Type: application/json" -d '{ + "inputs": "What is AI?", + "parameters": { + "best_of": 1, + "decoder_input_details": true, + "details": true, + "do_sample": true, + "frequency_penalty": 0.1, + "grammar": { + "type": "json", + "value": "string" + }, + "max_new_tokens": 32, + "repetition_penalty": 1.03, + "return_full_text": false, + "seed": 0.1, + "stop": [ + "photographer" + ], + "temperature": 0.5, + "top_k": 10, + "top_n_tokens": 5, + "top_p": 0.95, + "truncate": true, + "typical_p": 0.95, + "watermark": true + } +}' http://localhost:8000/generate_stream +``` + +Sample output: +```bash +data: {"token": {"id": 663359, "text": "", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 300560, "text": "\n", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 725120, "text": "Artificial Intelligence ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 734609, "text": "(AI) is ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 362235, "text": "a branch of computer ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 380983, "text": "science that attempts to ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 249979, "text": "simulate the way that ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 972663, "text": "the human brain ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 793301, "text": "works. It is a ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 501380, "text": "branch of computer ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 673232, "text": "", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 2, "text": "", "logprob": 0.0, "special": true}, "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", "details": {"finish_reason": "eos_token", "generated_tokens": 31, "prefill_tokens": 4, "seed": 2023}, "special_ret": {"tensor": []}} +``` + + ### Launch RESTful API server To start an OpenAI API server that provides compatible APIs using IPEX-LLM backend, you can launch the `openai_api_server` and follow this [doc](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) to use it. diff --git a/python/llm/src/ipex_llm/serving/fastchat/README.md b/python/llm/src/ipex_llm/serving/fastchat/README.md index f1f469ae..7204eaeb 100644 --- a/python/llm/src/ipex_llm/serving/fastchat/README.md +++ b/python/llm/src/ipex_llm/serving/fastchat/README.md @@ -124,6 +124,140 @@ This is the user interface that users will interact with. By following these steps, you will be able to serve your models using the web UI with IPEX-LLM as the backend. You can open your browser and chat with a model now. 
+ +### Launch TGI Style API server + +When you have started the controller and the worker, you can start TGI Style API server as follows: + +```bash +python3 -m ipex_llm.serving.fastchat.tgi_api_server --host localhost --port 8000 +``` +You can use `curl` for observing the output of the api + +#### Using /generate API + +This is to send a sentence as inputs in the request, and is expected to receive a response containing model-generated answer. + +```bash +curl -X POST -H "Content-Type: application/json" -d '{ + "inputs": "What is AI?", + "parameters": { + "best_of": 1, + "decoder_input_details": true, + "details": true, + "do_sample": true, + "frequency_penalty": 0.1, + "grammar": { + "type": "json", + "value": "string" + }, + "max_new_tokens": 32, + "repetition_penalty": 1.03, + "return_full_text": false, + "seed": 0.1, + "stop": [ + "photographer" + ], + "temperature": 0.5, + "top_k": 10, + "top_n_tokens": 5, + "top_p": 0.95, + "truncate": true, + "typical_p": 0.95, + "watermark": true + } +}' http://localhost:8000/generate +``` + +Sample output: +```bash +{ + "details": { + "best_of_sequences": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer " + }, + "finish_reason": "length", + "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", + "generated_tokens": 31 + } + ] + }, + "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", + "usage": { + "prompt_tokens": 4, + "total_tokens": 35, + "completion_tokens": 31 + } +} +``` + +#### Using /generate_stream API + +This is to send a sentence as inputs in the request, and a long connection will be opened to continuously receive multiple responses containing model-generated answer. 
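Before the streaming example, note that the non-streaming `/generate` endpoint documented above can also be called from Python rather than `curl`. The snippet below is a minimal sketch (most parameters are optional and fall back to server-side defaults) and assumes the API server from the previous step is listening on `localhost:8000`. The `curl` request for the streaming endpoint follows.

```python
import requests

payload = {
    "inputs": "What is AI?",
    "parameters": {"max_new_tokens": 32, "do_sample": True, "temperature": 0.5},
}

response = requests.post("http://localhost:8000/generate", json=payload, timeout=600)
result = response.json()

# `generated_text` holds the answer; `usage` reports prompt/completion token counts.
print(result["generated_text"])
print(result["usage"])
```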
+ +```bash +curl -X POST -H "Content-Type: application/json" -d '{ + "inputs": "What is AI?", + "parameters": { + "best_of": 1, + "decoder_input_details": true, + "details": true, + "do_sample": true, + "frequency_penalty": 0.1, + "grammar": { + "type": "json", + "value": "string" + }, + "max_new_tokens": 32, + "repetition_penalty": 1.03, + "return_full_text": false, + "seed": 0.1, + "stop": [ + "photographer" + ], + "temperature": 0.5, + "top_k": 10, + "top_n_tokens": 5, + "top_p": 0.95, + "truncate": true, + "typical_p": 0.95, + "watermark": true + } +}' http://localhost:8000/generate_stream +``` + +Sample output: +```bash +data: {"token": {"id": 663359, "text": "", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 300560, "text": "\n", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 725120, "text": "Artificial Intelligence ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 734609, "text": "(AI) is ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 362235, "text": "a branch of computer ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 380983, "text": "science that attempts to ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 249979, "text": "simulate the way that ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 972663, "text": "the human brain ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 793301, "text": "works. It is a ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 501380, "text": "branch of computer ", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 673232, "text": "", "logprob": 0.0, "special": false}, "generated_text": null, "details": null, "special_ret": null} + +data: {"token": {"id": 2, "text": "", "logprob": 0.0, "special": true}, "generated_text": "\nArtificial Intelligence (AI) is a branch of computer science that attempts to simulate the way that the human brain works. It is a branch of computer ", "details": {"finish_reason": "eos_token", "generated_tokens": 31, "prefill_tokens": 4, "seed": 2023}, "special_ret": {"tensor": []}} +``` + + ### Launch RESTful API server To start an OpenAI API server that provides compatible APIs using IPEX-LLM backend, you can launch the `openai_api_server` and follow this [doc](https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md) to use it. 
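As a quick sanity check of the OpenAI-compatible server, you can query it with plain HTTP once it is running. The sketch below is illustrative only; the port and model name depend on how you launched the server and worker.

```python
import requests

base_url = "http://localhost:8000/v1"  # adjust to your openai_api_server address

# List the models registered with the controller.
models = requests.get(f"{base_url}/models", timeout=60).json()
print([m["id"] for m in models["data"]])

# Send a chat completion request to the first available model.
resp = requests.post(
    f"{base_url}/chat/completions",
    json={
        "model": models["data"][0]["id"],
        "messages": [{"role": "user", "content": "What is AI?"}],
        "max_tokens": 64,
    },
    timeout=600,
)
print(resp.json()["choices"][0]["message"]["content"])
```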
diff --git a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py index c491be4b..0ee07ba2 100644 --- a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py +++ b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py @@ -104,7 +104,7 @@ class BigDLLLMWorker(BaseModelWorker): speculative, load_low_bit_model, ) - if benchmark.lower() == "true": + if benchmark.lower() == "true" and not speculative: from ipex_llm.utils.benchmark_util import BenchmarkWrapper self.model = BenchmarkWrapper(self.model, do_print=True) logger.info(f"enable benchmark successfully") diff --git a/python/llm/src/ipex_llm/serving/fastchat/tgi_api_protocol.py b/python/llm/src/ipex_llm/serving/fastchat/tgi_api_protocol.py new file mode 100644 index 00000000..ea8e9393 --- /dev/null +++ b/python/llm/src/ipex_llm/serving/fastchat/tgi_api_protocol.py @@ -0,0 +1,222 @@ +from typing import Literal, Optional, List, Dict, Any, Union + +import time + +import shortuuid +from pydantic import BaseModel, Field +import sys + +pseudo_infinite_int = sys.maxsize + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{shortuuid.random()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = True + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "fastchat" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = [] + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +class Grammar(BaseModel): + type: str = "json" + value: str = "string" + + +class ChatCompletionParam(BaseModel): + best_of: Optional[int] = 1 + decoder_input_details: Optional[bool] = True + details: Optional[bool] = False + do_sample: Optional[bool] = False + frequency_penalty: Optional[float] = 0.1 + grammar: Optional[Grammar] = Grammar() + max_new_tokens: Optional[int] = None + repetition_penalty: Optional[float] = 1.0 + return_full_text: Optional[bool] = False + seed: Optional[float] = None + stop: Optional[List[str]] = [] + temperature: Optional[float] = 1.0 + top_k: Optional[int] = pseudo_infinite_int + top_n_tokens: Optional[int] = 5 + top_p: Optional[float] = 1.0 + truncate: Optional[bool] = False + typical_p: Optional[float] = 0.95 + watermark: Optional[bool] = True + + +class ChatCompletionRequest(BaseModel): + inputs: str + parameters: ChatCompletionParam + model: Optional[str] = "" + + +class ChatCompletionChoice(BaseModel): + index: int + message: ChatMessage + finish_reason: Optional[Literal["stop", "length"]] = 
None + generated_text: str + generated_tokens: Optional[int] = 0 + + +class ChatCompletionDetails(BaseModel): + best_of_sequences: List[ChatCompletionChoice] + + +class ChatCompletionResponse(BaseModel): + details: ChatCompletionDetails + generated_text: str + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionStreamChoice(BaseModel): + index: int + message: DeltaMessage + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamDetails(BaseModel): + best_of_sequences: List[ChatCompletionStreamChoice] + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + details: ChatCompletionStreamDetails + generated_text: Optional[str] + + +class TokenCheckRequestItem(BaseModel): + model: str + prompt: str + max_tokens: int + + +class TokenCheckRequest(BaseModel): + prompts: List[TokenCheckRequestItem] + + +class TokenCheckResponseItem(BaseModel): + fits: bool + tokenCount: int + contextLength: int + + +class TokenCheckResponse(BaseModel): + prompts: List[TokenCheckResponseItem] + + +class EmbeddingsRequest(BaseModel): + model: Optional[str] = None + engine: Optional[str] = None + input: Union[str, List[Any]] + user: Optional[str] = None + encoding_format: Optional[str] = None + + +class EmbeddingsResponse(BaseModel): + object: str = "list" + data: List[Dict[str, Any]] + model: str + usage: UsageInfo + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[str, List[Any]] + suffix: Optional[str] = None + temperature: Optional[float] = 0.7 + n: Optional[int] = 1 + max_tokens: Optional[int] = 16 + stop: Optional[Union[str, List[str]]] = None + stream: Optional[bool] = False + top_p: Optional[float] = 1.0 + top_k: Optional[int] = -1 + logprobs: Optional[int] = None + echo: Optional[bool] = False + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + use_beam_search: Optional[bool] = False + best_of: Optional[int] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] diff --git a/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server b/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server new file mode 100644 index 00000000..499239c9 --- /dev/null +++ b/python/llm/src/ipex_llm/serving/fastchat/tgi_api_server @@ -0,0 +1,758 @@ +# +# Copyright 2016 The BigDL Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Modified from +https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py + +A server that provides OpenAI-compatible RESTful APIs. It supports: + +- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat) +- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions) +- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings) + +Usage: +python3 -m fastchat.serve.openai_api_server +""" +import asyncio +import argparse +import json +import os +from typing import Generator, Optional, Union, Dict, List, Any + +import aiohttp +import fastapi +from fastapi import Depends, HTTPException +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +import httpx + +try: + from pydantic.v1 import BaseSettings +except ImportError: + from pydantic import BaseSettings +import shortuuid +import tiktoken +import uvicorn + +from fastchat.constants import ( + WORKER_API_TIMEOUT, + WORKER_API_EMBEDDING_BATCH_SIZE, + ErrorCode, +) +from fastchat.conversation import Conversation, SeparatorStyle +from .tgi_api_protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamChoice, + ChatCompletionStreamDetails, + ChatCompletionStreamResponse, + ChatMessage, + ChatCompletionChoice, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + CompletionResponseStreamChoice, + CompletionStreamResponse, + EmbeddingsRequest, + EmbeddingsResponse, + ErrorResponse, + LogProbs, + ModelCard, + ModelList, + ModelPermission, + UsageInfo, + ChatCompletionParam, + ChatCompletionDetails, + ChatCompletionResponse, +) +from fastchat.protocol.api_protocol import ( + APIChatCompletionRequest, + APITokenCheckRequest, + APITokenCheckResponse, + APITokenCheckResponseItem, +) +from fastchat.utils import build_logger + +logger = build_logger("tgi_api_server", "tgi_api_server.log") + +conv_template_map = {} + +fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600) + +async def fetch_remote(url, pload=None, name=None): + async with aiohttp.ClientSession(timeout=fetch_timeout) as session: + async with session.post(url, json=pload) as response: + chunks = [] + if response.status != 200: + ret = { + "text": f"{response.reason}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + return json.dumps(ret) + + async for chunk, _ in response.content.iter_chunks(): + chunks.append(chunk) + output = b"".join(chunks) + + if name is not None: + res = json.loads(output) + if name != "": + res = res[name] + return res + + return output + + +class AppSettings(BaseSettings): + # The address of the model controller. 
+ controller_address: str = "http://localhost:21001" + api_keys: Optional[List[str]] = None + + +app_settings = AppSettings() +app = fastapi.FastAPI() +headers = {"User-Agent": "FastChat API Server"} +get_bearer_token = HTTPBearer(auto_error=False) + + +async def check_api_key( + auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token), +) -> str: + if app_settings.api_keys: + if auth is None or (token := auth.credentials) not in app_settings.api_keys: + raise HTTPException( + status_code=401, + detail={ + "error": { + "message": "", + "type": "invalid_request_error", + "param": None, + "code": "invalid_api_key", + } + }, + ) + return token + else: + # api_keys not set; allow all + return None + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse( + ErrorResponse(message=message, code=code).dict(), status_code=400 + ) + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request, exc): + return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc)) + + +async def check_model(request) -> Optional[JSONResponse]: + controller_address = app_settings.controller_address + ret = None + + models = await fetch_remote(controller_address + "/list_models", None, "models") + if request.model not in models: + ret = create_error_response( + ErrorCode.INVALID_MODEL, + f"Only {'&&'.join(models)} allowed now, your model {request.model}", + ) + return ret + + +async def check_length(request, prompt, max_tokens, worker_addr): + if ( + not isinstance(max_tokens, int) or max_tokens <= 0 + ): # model worker not support max_tokens=None + max_tokens = 1024 * 1024 + + context_len = await fetch_remote( + worker_addr + "/model_details", {"model": request.model}, "context_length" + ) + token_num = await fetch_remote( + worker_addr + "/count_token", + {"model": request.model, "prompt": prompt}, + "count", + ) + length = min(max_tokens, context_len - token_num) + + if length <= 0: + return None, create_error_response( + ErrorCode.CONTEXT_OVERFLOW, + f"This model's maximum context length is {context_len} tokens. However, your messages resulted in {token_num} tokens. 
Please reduce the length of the messages.",
+        )
+
+    return length, None
+
+
+def check_requests(request) -> Optional[JSONResponse]:
+    # Check all params
+    if request.parameters.max_new_tokens is not None and request.parameters.max_new_tokens <= 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.max_new_tokens} is less than the minimum of 1 - 'max_tokens'",
+        )
+    if request.parameters.best_of is not None and request.parameters.best_of <= 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.best_of} is less than the minimum of 1 - 'n'",
+        )
+    if request.parameters.temperature is not None and request.parameters.temperature < 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.temperature} is less than the minimum of 0 - 'temperature'",
+        )
+    if request.parameters.temperature is not None and request.parameters.temperature > 2:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.temperature} is greater than the maximum of 2 - 'temperature'",
+        )
+    if request.parameters.top_p is not None and request.parameters.top_p < 0:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.top_p} is less than the minimum of 0 - 'top_p'",
+        )
+    if request.parameters.top_p is not None and request.parameters.top_p > 1:
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.top_p} is greater than the maximum of 1 - 'top_p'",
+        )
+    if request.parameters.top_k is not None and (request.parameters.top_k > -1 and request.parameters.top_k < 1):
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.top_k} is out of range. Either set top_k to -1 or >=1.",
+        )
+    if request.parameters.stop is not None and (
+        not isinstance(request.parameters.stop, str) and not isinstance(request.parameters.stop, list)
+    ):
+        return create_error_response(
+            ErrorCode.PARAM_OUT_OF_RANGE,
+            f"{request.parameters.stop} is not valid under any of the given schemas - 'stop'",
+        )
+
+    return None
+
+
+def process_input(model_name, inp):
+    if isinstance(inp, str):
+        inp = [inp]
+    elif isinstance(inp, list):
+        if isinstance(inp[0], int):
+            try:
+                decoding = tiktoken.model.encoding_for_model(model_name)
+            except KeyError:
+                logger.warning("Warning: model not found. Using cl100k_base encoding.")
+                model = "cl100k_base"
+                decoding = tiktoken.get_encoding(model)
+            inp = [decoding.decode(inp)]
+        elif isinstance(inp[0], list):
+            try:
+                decoding = tiktoken.model.encoding_for_model(model_name)
+            except KeyError:
+                logger.warning("Warning: model not found. 
Using cl100k_base encoding.") + model = "cl100k_base" + decoding = tiktoken.get_encoding(model) + inp = [decoding.decode(text) for text in inp] + + return inp + + +def create_openai_logprobs(logprob_dict): + """Create OpenAI-style logprobs.""" + return LogProbs(**logprob_dict) if logprob_dict is not None else None + + +def _add_to_set(s, new_stop): + if not s: + return + if isinstance(s, str): + new_stop.add(s) + else: + new_stop.update(s) + + +async def get_gen_params( + model_name: str, + worker_addr: str, + messages: Union[str, List[Dict[str, str]]], + *, + temperature: float, + top_p: float, + top_k: Optional[int], + presence_penalty: Optional[float], + frequency_penalty: Optional[float], + max_tokens: Optional[int], + echo: Optional[bool], + logprobs: Optional[int] = None, + stop: Optional[Union[str, List[str]]], + best_of: Optional[int] = None, + use_beam_search: Optional[bool] = None, +) -> Dict[str, Any]: + conv = await get_conv(model_name, worker_addr) + conv = Conversation( + name=conv["name"], + system_template=conv["system_template"], + system_message=conv["system_message"], + roles=conv["roles"], + messages=list(conv["messages"]), # prevent in-place modification + offset=conv["offset"], + sep_style=SeparatorStyle(conv["sep_style"]), + sep=conv["sep"], + sep2=conv["sep2"], + stop_str=conv["stop_str"], + stop_token_ids=conv["stop_token_ids"], + ) + + if isinstance(messages, str): + prompt = messages + images = [] + else: + for message in messages: + msg_role = message["role"] + if msg_role == "system": + conv.set_system_message(message["content"]) + elif msg_role == "user": + if type(message["content"]) == list: + image_list = [ + item["image_url"]["url"] + for item in message["content"] + if item["type"] == "image_url" + ] + text_list = [ + item["text"] + for item in message["content"] + if item["type"] == "text" + ] + + text = "\n".join(text_list) + conv.append_message(conv.roles[0], (text, image_list)) + else: + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. 
+ conv.append_message(conv.roles[1], None) + prompt = conv.get_prompt() + images = conv.get_images() + + gen_params = { + "model": model_name, + "prompt": prompt, + "temperature": temperature, + "logprobs": logprobs, + "top_p": top_p, + "top_k": top_k, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "max_new_tokens": max_tokens, + "echo": echo, + "stop_token_ids": conv.stop_token_ids, + } + + if len(images) > 0: + gen_params["images"] = images + + if best_of is not None: + gen_params.update({"best_of": best_of}) + if use_beam_search is not None: + gen_params.update({"use_beam_search": use_beam_search}) + + new_stop = set() + _add_to_set(stop, new_stop) + _add_to_set(conv.stop_str, new_stop) + + gen_params["stop"] = list(new_stop) + + logger.debug(f"==== request ====\n{gen_params}") + return gen_params + + +async def get_worker_address(model_name: str) -> str: + """ + Get worker address based on the requested model + + :param model_name: The worker's model name + :return: Worker address from the controller + :raises: :class:`ValueError`: No available worker for requested model + """ + controller_address = app_settings.controller_address + worker_addr = await fetch_remote( + controller_address + "/get_worker_address", {"model": model_name}, "address" + ) + + # No available worker + if worker_addr == "": + raise ValueError(f"No available worker for {model_name}") + logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}") + return worker_addr + + +async def get_conv(model_name: str, worker_addr: str): + conv_template = conv_template_map.get((worker_addr, model_name)) + if conv_template is None: + conv_template = await fetch_remote( + worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv" + ) + conv_template_map[(worker_addr, model_name)] = conv_template + return conv_template + +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +async def show_available_models(): + controller_address = app_settings.controller_address + ret = await fetch_remote(controller_address + "/refresh_all_workers") + models = await fetch_remote(controller_address + "/list_models", None, "models") + + models.sort() + # TODO: return real model permission details + model_cards = [] + for m in models: + model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) + return ModelList(data=model_cards) + +async def get_last_model_name_from_list(): + models = await show_available_models() + return models.data[-1].id + +@app.post("/generate", dependencies=[Depends(check_api_key)]) +async def create_chat_completion(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + + request.model = await get_last_model_name_from_list() + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.inputs, + temperature=request.parameters.temperature, + top_p=request.parameters.top_p, + top_k=request.parameters.top_k, + presence_penalty=request.parameters.repetition_penalty, + frequency_penalty=request.parameters.frequency_penalty, + max_tokens=request.parameters.max_new_tokens, + echo=False, + stop=request.parameters.stop, + ) + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + choices = [] + chat_completions = [] + for i in 
range(request.parameters.best_of): + content = asyncio.create_task(generate_completion(gen_params, worker_addr)) + chat_completions.append(content) + try: + all_tasks = await asyncio.gather(*chat_completions) + except Exception as e: + return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) + usage = UsageInfo() + for i, content in enumerate(all_tasks): + if isinstance(content, str): + content = json.loads(content) + + if content["error_code"] != 0: + return create_error_response(content["error_code"], content["text"]) + choices.append( + ChatCompletionChoice( + index=i, + message=ChatMessage(role="assistant", content=content["text"]), + finish_reason=content.get("finish_reason", "stop"), + generated_text=content["text"] + ) + ) + if "usage" in content: + task_usage = UsageInfo.parse_obj(content["usage"]) + choices[-1].generated_tokens = task_usage.dict()["completion_tokens"] + for usage_key, usage_value in task_usage.dict().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + + details = ChatCompletionDetails( + best_of_sequences=choices + ) + generated_text = choices[0].message.content + + return ChatCompletionResponse(details=details, usage=usage, generated_text=generated_text) + + +@app.post("/generate_stream", dependencies=[Depends(check_api_key)]) +async def create_chat_completion_stream(request: ChatCompletionRequest): + """Creates a completion for the chat message""" + + request.model = await get_last_model_name_from_list() + + worker_addr = await get_worker_address(request.model) + + gen_params = await get_gen_params( + request.model, + worker_addr, + request.inputs, + temperature=request.parameters.temperature, + top_p=request.parameters.top_p, + top_k=request.parameters.top_k, + presence_penalty=request.parameters.repetition_penalty, + frequency_penalty=request.parameters.frequency_penalty, + max_tokens=request.parameters.max_new_tokens, + echo=False, + stop=request.parameters.stop, + ) + + max_new_tokens, error_check_ret = await check_length( + request, + gen_params["prompt"], + gen_params["max_new_tokens"], + worker_addr, + ) + + if error_check_ret is not None: + return error_check_ret + + gen_params["max_new_tokens"] = max_new_tokens + + generator = chat_completion_stream_generator( + request.model, gen_params, request.parameters.best_of, worker_addr + ) + return StreamingResponse(generator, media_type="text/event-stream") + + +async def chat_completion_stream_generator( + model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str +) -> Generator[str, Any, None]: + """ + Event stream format: + https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format + """ + id = f"chatcmpl-{shortuuid.random()}" + finish_stream_events = [] + + for i in range(n): + # First chunk with role + choice_data = ChatCompletionStreamChoice( + index=i, + message=DeltaMessage(role="assistant"), + finish_reason=None, + ) + details = ChatCompletionStreamDetails(best_of_sequences=[choice_data]) + + chunk = ChatCompletionStreamResponse( + id=id, details=details, generated_text="" + ) + yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n" + + previous_text = "" + async for content in generate_completion_stream(gen_params, worker_addr): + if content["error_code"] != 0: + yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + return + decoded_unicode = content["text"].replace("\ufffd", "") + delta_text = 
decoded_unicode[len(previous_text) :]
+            previous_text = (
+                decoded_unicode
+                if len(decoded_unicode) > len(previous_text)
+                else previous_text
+            )
+
+            if len(delta_text) == 0:
+                delta_text = None
+            choice_data = ChatCompletionStreamChoice(
+                index=i,
+                message=DeltaMessage(content=delta_text),
+                finish_reason=content.get("finish_reason", None),
+            )
+            details = ChatCompletionStreamDetails(best_of_sequences=[choice_data])
+            chunk = ChatCompletionStreamResponse(
+                id=id,
+                details=details,
+                generated_text=delta_text
+            )
+            if delta_text is None:
+                if content.get("finish_reason", None) is not None:
+                    finish_stream_events.append(chunk)
+                continue
+            yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
+    # The last delta message has no "content" field, so exclude unset fields when dumping it.
+    for finish_chunk in finish_stream_events:
+        yield f"data: {json.dumps(finish_chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
+    yield "data: [DONE]\n\n"
+
+async def generate_completion_stream_generator(
+    request: CompletionRequest, n: int, worker_addr: str
+):
+    model_name = request.model
+    id = f"cmpl-{shortuuid.random()}"
+    finish_stream_events = []
+    for text in request.prompt:
+        for i in range(n):
+            previous_text = ""
+            gen_params = await get_gen_params(
+                request.model,
+                worker_addr,
+                text,
+                temperature=request.temperature,
+                top_p=request.top_p,
+                top_k=request.top_k,
+                presence_penalty=request.presence_penalty,
+                frequency_penalty=request.frequency_penalty,
+                max_tokens=request.max_tokens,
+                logprobs=request.logprobs,
+                echo=request.echo,
+                stop=request.stop,
+            )
+            async for content in generate_completion_stream(gen_params, worker_addr):
+                if content["error_code"] != 0:
+                    # Emit the error payload itself; `chunk` is not defined yet on the first event.
+                    yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
+                    yield "data: [DONE]\n\n"
+                    return
+                decoded_unicode = content["text"].replace("\ufffd", "")
+                delta_text = decoded_unicode[len(previous_text) :]
+                previous_text = (
+                    decoded_unicode
+                    if len(decoded_unicode) > len(previous_text)
+                    else previous_text
+                )
+                # todo: index is not apparent
+                choice_data = CompletionResponseStreamChoice(
+                    index=i,
+                    text=delta_text,
+                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
+                    finish_reason=content.get("finish_reason", None),
+                )
+                chunk = CompletionStreamResponse(
+                    id=id,
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=model_name,
+                )
+                if len(delta_text) == 0:
+                    if content.get("finish_reason", None) is not None:
+                        finish_stream_events.append(chunk)
+                    continue
+                yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n"
+    # The last delta message has no "content" field, so exclude unset fields when dumping it.
+ for finish_chunk in finish_stream_events: + yield f"data: {json.dumps(chunk.model_dump(exclude_unset=True), ensure_ascii=False)}\n\n" + yield "data: [DONE]\n\n" + + +async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str): + controller_address = app_settings.controller_address + async with httpx.AsyncClient() as client: + delimiter = b"\0" + async with client.stream( + "POST", + worker_addr + "/worker_generate_stream", + headers=headers, + json=payload, + timeout=WORKER_API_TIMEOUT, + ) as response: + # content = await response.aread() + buffer = b"" + async for raw_chunk in response.aiter_raw(): + buffer += raw_chunk + while (chunk_end := buffer.find(delimiter)) >= 0: + chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :] + if not chunk: + continue + yield json.loads(chunk.decode()) + + +async def generate_completion(payload: Dict[str, Any], worker_addr: str): + return await fetch_remote(worker_addr + "/worker_generate", payload, "") + +### END GENERAL API - NOT OPENAI COMPATIBLE ### + + +def create_openai_api_server(): + parser = argparse.ArgumentParser( + description="FastChat ChatGPT-Compatible RESTful API server." + ) + parser.add_argument("--host", type=str, default="localhost", help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--controller-address", type=str, default="http://localhost:21001" + ) + parser.add_argument( + "--allow-credentials", action="store_true", help="allow credentials" + ) + parser.add_argument( + "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" + ) + parser.add_argument( + "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" + ) + parser.add_argument( + "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" + ) + parser.add_argument( + "--api-keys", + type=lambda s: s.split(","), + help="Optional list of comma separated API keys", + ) + parser.add_argument( + "--ssl", + action="store_true", + required=False, + default=False, + help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", + ) + args = parser.parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + app_settings.controller_address = args.controller_address + app_settings.api_keys = args.api_keys + + logger.info(f"args: {args}") + return args + + +if __name__ == "__main__": + args = create_openai_api_server() + if args.ssl: + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level="info", + ssl_keyfile=os.environ["SSL_KEYFILE"], + ssl_certfile=os.environ["SSL_CERTFILE"], + ) + else: + uvicorn.run(app, host=args.host, port=args.port, log_level="info") \ No newline at end of file
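If the server is launched with `--api-keys`, every request must carry one of the configured keys as a bearer token (see `check_api_key` above). A minimal illustrative client sketch, with a placeholder key and the default `localhost:8000` address:

```python
import requests

API_KEY = "my-secret-key"  # placeholder; must match one of the --api-keys values

response = requests.post(
    "http://localhost:8000/generate",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"inputs": "What is AI?", "parameters": {"max_new_tokens": 32}},
    timeout=600,
)
print(response.json()["generated_text"])
```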