# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
import uuid
from typing import Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated

# from vllm.sampling_params import SamplingParams


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


class OpenAIBaseModel(BaseModel):
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by the chat template in the tokenizer "
         "config of the model."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    # doc: end-chat-completion-extra-params

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
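

# Editorial note (not part of the upstream file): the guided_* fields declared
# above are mutually exclusive. Because check_guided_decoding_count runs as a
# mode="before" validator, a payload that sets more than one of them, e.g.
#     {"guided_json": {...}, "guided_regex": "\\d+"}
# is rejected with a ValueError before per-field validation happens.

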
class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    # doc: end-completion-extra-params

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
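

# Editorial note: LogProbs mirrors the logprobs object of the legacy OpenAI
# Completions API (parallel lists indexed by output token); the completion and
# chat response models below reuse it for their `logprobs` field.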
class LogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
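

# ---------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the upstream module): a quick
# smoke test of the request models defined above. It exercises only the code
# in this file plus pydantic; the model name and message contents are
# placeholders, not real deployments.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pydantic

    # A minimal chat request; unset fields fall back to the defaults declared
    # on ChatCompletionRequest.
    req = ChatCompletionRequest(
        model="example-model",
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.2,
        stream=True,
    )
    print(req.model_dump_json(exclude_unset=True))

    # Extra fields are rejected because OpenAIBaseModel sets extra="forbid".
    try:
        ChatCompletionRequest(
            model="example-model",
            messages=[{"role": "user", "content": "hi"}],
            not_a_real_field=1,
        )
    except pydantic.ValidationError as e:
        print(f"rejected extra field: {e.error_count()} error(s)")

    # Only one guided_* constraint may be set per request; the mode="before"
    # validator raises when two are given.
    try:
        ChatCompletionRequest(
            model="example-model",
            messages=[{"role": "user", "content": "hi"}],
            guided_regex=r"\d+",
            guided_choice=["yes", "no"],
        )
    except pydantic.ValidationError as e:
        print(f"conflicting guided decoding params: {e.error_count()} error(s)")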