Refactor fastapi-serving and add one card serving (#11581)
* init fastapi-serving one card
* mv api code to source
* update worker
* update for style-check
* add worker
* update bash
* update
* update worker name and add readme
* rename update
* rename to fastapi
parent 373ccbbb0c
commit 9c15abf825

19 changed files with 583 additions and 367 deletions
@@ -61,7 +61,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
     cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \
     # Download pp_serving
     mkdir -p /llm/pp_serving && \
-    cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-FastAPI/*.py /llm/pp_serving/ && \
+    cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-Serving/*.py /llm/pp_serving/ && \
     # Install related library of benchmarking
     pip install pandas omegaconf && \
     chmod +x /llm/benchmark.sh && \
@@ -1,346 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch.nn.parallel
import torch.distributed as dist
import os

from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers import init_pipeline_parallel, ModelRunner
import oneccl_bindings_for_pytorch
import json

from transformers.utils import logging
logger = logging.get_logger(__name__)

init_pipeline_parallel()

my_rank = dist.get_rank()
my_size = dist.get_world_size()
device = f"xpu:{my_rank}"
logger.info(f"rank: {my_rank}, size: {my_size}")

import time
from transformers import AutoTokenizer
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
import asyncio, uuid
from typing import Dict, List, Optional, Any, Callable, Union
import argparse


class PromptRequest(BaseModel):
    prompt: str
    n_predict: Optional[int] = 256
    req_type: str = 'completion'

from openai.types.chat import ChatCompletionMessageParam
class ChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionMessageParam]
    model: str
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False

empty_req = PromptRequest(prompt="", n_predict=0)

app = FastAPI()
global tokenizer
global local_model

request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
streamer_dict = {}
local_rank = my_rank


from openai_protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponse,
    ChatMessage,
    DeltaMessage,
    CompletionResponseChoice,
    CompletionResponse,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)


async def chat_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
    while True:
        if not delta_text_queue.empty():
            with local_model.dict_lock:
                remain, delta_text = await delta_text_queue.get()
            # print(remain)
            choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(role="assistant", content=delta_text),
                            logprobs=None,
                            finish_reason=None)
            chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            choices=[choice_data],
                            model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            index = index + 1
            if remain == 0:
                choice_data = ChatCompletionResponseStreamChoice(
                                index=index,
                                delta=DeltaMessage(role="assistant", content=None),
                                logprobs=None,
                                finish_reason="length")
                chunk = ChatCompletionStreamResponse(
                                id=request_id,
                                choices=[choice_data],
                                model=model_name)
                data = chunk.model_dump_json(exclude_unset=True)
                yield f"data: {data}\n\n"
                break
        else:
            await asyncio.sleep(0)
    local_model.streamer.pop(request_id, None)


async def completion_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
    while True:
        if not delta_text_queue.empty():
            with local_model.dict_lock:
                remain, delta_text = await delta_text_queue.get()
            # print(remain)
            choice_data = CompletionResponseStreamChoice(
                            index=index,
                            text=delta_text,
                            logprobs=None,
                            finish_reason=None)
            chunk = CompletionStreamResponse(
                            id=request_id,
                            choices=[choice_data],
                            model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            index = index + 1
            if remain == 0:
                choice_data = CompletionResponseStreamChoice(
                                index=index,
                                text="",
                                logprobs=None,
                                finish_reason="length")
                chunk = CompletionStreamResponse(
                                id=request_id,
                                choices=[choice_data],
                                model=model_name)
                data = chunk.model_dump_json(exclude_unset=True)
                yield f"data: {data}\n\n"
                break
        else:
            await asyncio.sleep(0)
    local_model.streamer.pop(request_id, None)


async def generator(local_model, delta_text_queue, request_id):
    while True:
        if not delta_text_queue.empty():
            with local_model.dict_lock:
                remain, delta_text = await delta_text_queue.get()
            yield delta_text
            if remain == 0:
                break
        else:
            await asyncio.sleep(0)
    local_model.streamer.pop(request_id, None)


@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4())
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            output_str = []
            async for item in generator(local_model, cur_streamer, request_id):
                output_str.append(item)
            return request_id, "".join(output_str)


async def generate_stream(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4()) + "stream"
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            if prompt_request.req_type == 'completion':
                cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
            elif prompt_request.req_type == 'chat':
                cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
            else:
                invalidInputError(False, "Invalid Request Type.")

            return request_id, StreamingResponse(
                content=cur_generator, media_type="text/event-stream"
            )

@app.post("/generate_stream/")
async def generate_stream_api(prompt_request: PromptRequest):
    request_id, result = await generate_stream(prompt_request)
    return result


DEFAULT_SYSTEM_PROMPT = """\
"""

def get_prompt(messages) -> str:
    prompt = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role == "system":
            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
        elif role == "user":
            prompt += f"[INST] {content} [/INST] "
        elif role == "assistant":
            prompt += f"{content} "
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt.strip()

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
    prompt_request = PromptRequest(
        prompt=get_prompt(request.messages),
        n_predict=n_predict,
        req_type="chat"
    )
    if request.stream:
        request_id, result = await generate_stream(prompt_request)
    else:
        request_id, result = await generate(prompt_request)
        choice_data = ChatCompletionResponseChoice(
                        index=0,
                        message=ChatMessage(role="assistant", content=result),
                        logprobs=None,
                        finish_reason="length")
        result = ChatCompletionResponse(
                        id=request_id,
                        choices=[choice_data],
                        model=model_name)
    return result

@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
    prompt_request = PromptRequest(
        prompt=request.prompt,
        n_predict=n_predict,
        req_type="completion"
    )
    if request.stream:
        request_id, result = await generate_stream(prompt_request)
    else:
        request_id, result = await generate(prompt_request)
        choice_data = CompletionResponseChoice(
                            index=0,
                            text=result,
                            logprobs=None,
                            finish_reason="length")
        result = CompletionResponse(
                            id=request_id,
                            choices=[choice_data],
                            model=model_name)
    return result


async def process_requests(local_model, result_dict):
    while True:
        await asyncio.sleep(0)
        await local_model.process_step(tokenizer, result_dict)


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests(local_model, result_dict))

async def main():
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')
    parser.add_argument('--max-num-seqs', type=int, default=8,
                        help='Max num sequences in a batch.')
    parser.add_argument('--max-prefilled-seqs', type=int, default=0,
                        help='Max num sequences in a batch during prefilling.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    max_num_seqs = args.max_num_seqs
    max_prefilled_seqs = args.max_prefilled_seqs

    # serialize model initialization so that we do not run out of CPU memory
    for i in range(my_size):
        if my_rank == i:
            logger.info("start model initialization")
            global local_model
            local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
            logger.info("model initialized")
        dist.barrier()
    # Load tokenizer
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if local_rank == 0:
        config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
        server = uvicorn.Server(config)
        await server.serve()
    else:
        while True:
            await asyncio.sleep(0)
            await local_model.process_step(tokenizer, result_dict)

if __name__ == "__main__":
    asyncio.run(main())
@@ -50,7 +50,14 @@ pip install transformers==4.40.0
 pip install trl==0.8.1
 ```
 
-### 2. Run pipeline parallel serving on multiple GPUs
+### 2-1. Run ipex-llm serving on one GPU card
 
+```bash
+# Need to set NUM_GPUS=1 and MODEL_PATH in run.sh first
+bash run.sh
+```
+
+### 2-2. Run pipeline parallel serving on multiple GPUs
+
 ```bash
 # Need to set MODEL_PATH in run.sh first
@@ -76,7 +83,7 @@ export http_proxy=
 export https_proxy=
 
 curl -X 'POST' \
-  'http://127.0.0.1:8000/generate/' \
+  'http://127.0.0.1:8000/generate' \
   -H 'accept: application/json' \
   -H 'Content-Type: application/json' \
   -d '{
@@ -99,7 +106,7 @@ Please change the test url accordingly.
 
 ```bash
 # set t/c to the number of concurrencies to test full throughput.
-wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate/ --timeout 1m
+wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate --timeout 1m
 ```
 
 ## 5. Using the `benchmark.py` Script
@@ -0,0 +1,78 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch.distributed as dist
from ipex_llm.transformers import init_pipeline_parallel, PPModelWorker
from ipex_llm.serving.fastapi import FastApp
from transformers.utils import logging
from transformers import AutoTokenizer
import uvicorn
import asyncio
from typing import Dict
import argparse
logger = logging.get_logger(__name__)

init_pipeline_parallel()
my_rank = dist.get_rank()
my_size = dist.get_world_size()
device = f"xpu:{my_rank}"
logger.info(f"rank: {my_rank}, size: {my_size}")
result_dict: Dict[str, str] = {}
local_rank = my_rank

async def main():
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging Pipeline-Parallel')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')
    parser.add_argument('--max-num-seqs', type=int, default=8,
                        help='Max num sequences in a batch.')
    parser.add_argument('--max-prefilled-seqs', type=int, default=0,
                        help='Max num sequences in a batch during prefilling.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    max_num_seqs = args.max_num_seqs
    max_prefilled_seqs = args.max_prefilled_seqs

    # serialize model initialization so that we do not run out of CPU memory
    for i in range(my_size):
        if my_rank == i:
            logger.info("start model initialization")
            local_model = PPModelWorker(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
            logger.info("model initialized")
        dist.barrier()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    myapp = FastApp(local_model, tokenizer)
    if local_rank == 0:
        config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
        server = uvicorn.Server(config)
        await server.serve()
    else:
        while True:
            await asyncio.sleep(0)
            await local_model.process_step(tokenizer, result_dict)

if __name__ == "__main__":
    asyncio.run(main())
@@ -40,4 +40,9 @@ export LOW_BIT="fp8"
 export MAX_NUM_SEQS="4"
 export MAX_PREFILLED_SEQS=0
 
-CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT --max-num-seqs $MAX_NUM_SEQS --max-prefilled-seqs $MAX_PREFILLED_SEQS
+if [[ $NUM_GPUS -eq 1 ]]; then
+    export ZE_AFFINITY_MASK=0
+    python serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT
+else
+    CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT --max-num-seqs $MAX_NUM_SEQS --max-prefilled-seqs $MAX_PREFILLED_SEQS
+fi
python/llm/example/GPU/Pipeline-Parallel-Serving/serving.py (new file, 53 lines)
@@ -0,0 +1,53 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from transformers.utils import logging
import time
from transformers import AutoTokenizer
import uvicorn
import asyncio
import argparse
from ipex_llm.serving.fastapi import FastApp
from ipex_llm.serving.fastapi import ModelWorker
logger = logging.get_logger(__name__)

async def main():
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging ipex-llm')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit

    local_model = ModelWorker(model_path, low_bit)
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    myapp = FastApp(local_model, tokenizer)
    config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
    server = uvicorn.Server(config)
    await server.serve()

if __name__ == "__main__":
    asyncio.run(main())
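For reference, below is a minimal client sketch against the `/generate` route that `FastApp` registers (see `api_server.py` further down). The host, port, prompt, and `n_predict` values are illustrative assumptions, not part of the commit.

```python
# Hypothetical client for the /generate endpoint; host/port/prompt are assumptions.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={"prompt": "What is AI?", "n_predict": 32},  # fields of PromptRequest
    timeout=300,
)
resp.raise_for_status()
# The handler returns (request_id, generated_text), which FastAPI serializes
# as a two-element JSON array.
request_id, text = resp.json()
print(request_id, text)
```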
python/llm/src/ipex_llm/serving/fastapi/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .api_server import FastApp
from .model_worker import ModelWorker
python/llm/src/ipex_llm/serving/fastapi/api_server.py (new file, 315 lines)
@@ -0,0 +1,315 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
from ipex_llm.utils.common import invalidInputError
from transformers.utils import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ipex_llm.utils.common import invalidInputError
import asyncio
import uuid
from typing import List, Optional, Union, Dict


result_dict: Dict[str, str] = {}
logger = logging.get_logger(__name__)


class PromptRequest(BaseModel):
    prompt: str
    n_predict: Optional[int] = 256
    req_type: str = 'completion'


class ChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionMessageParam]
    model: str
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


app = FastAPI()
global tokenizer
global local_model


class FastApp():
    def __init__(self, model, mytokenizer):
        global tokenizer
        global local_model
        local_model = model
        tokenizer = mytokenizer
        self.app = app


from .openai_protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponse,
    ChatMessage,
    DeltaMessage,
    CompletionResponseChoice,
    CompletionResponse,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)


def get_queue_next_token(delta_text_queue):
    timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
    delta_text = delta_text_queue.text_queue.get(timeout=timeout)
    if delta_text is None:
        remain = 0
    else:
        remain = 1
    return delta_text, remain

async def chat_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
    while True:
        if not hasattr(delta_text_queue, 'empty'):
            delta_text, remain = get_queue_next_token(delta_text_queue)
        else:
            if not delta_text_queue.empty():
                with local_model.dict_lock:
                    remain, delta_text = await delta_text_queue.get()
            else:
                await asyncio.sleep(0)
                continue
        if remain == 0 and delta_text is not None or remain != 0:
            choice_data = ChatCompletionResponseStreamChoice(
                index=index,
                delta=DeltaMessage(role="assistant", content=delta_text),
                logprobs=None,
                finish_reason=None)
            chunk = ChatCompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            index = index + 1
        if remain == 0:
            choice_data = ChatCompletionResponseStreamChoice(
                index=index,
                delta=DeltaMessage(role="assistant", content=None),
                logprobs=None,
                finish_reason="length")
            chunk = ChatCompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            break
    local_model.streamer.pop(request_id, None)


async def completion_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
    while True:
        if not hasattr(delta_text_queue, 'empty'):
            delta_text, remain = get_queue_next_token(delta_text_queue)
        else:
            if not delta_text_queue.empty():
                with local_model.dict_lock:
                    remain, delta_text = await delta_text_queue.get()
            else:
                await asyncio.sleep(0)
                continue
        if remain == 0 and delta_text is not None or remain != 0:
            choice_data = CompletionResponseStreamChoice(
                index=index,
                text=delta_text,
                logprobs=None,
                finish_reason=None)
            chunk = CompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            index = index + 1
        if remain == 0:
            choice_data = CompletionResponseStreamChoice(
                index=index,
                text="",
                logprobs=None,
                finish_reason="length")
            chunk = CompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            break
    local_model.streamer.pop(request_id, None)


async def generator(local_model, delta_text_queue, request_id):
    while True:
        if not hasattr(delta_text_queue, 'empty'):
            delta_text, remain = get_queue_next_token(delta_text_queue)
            if delta_text is None:
                break
            else:
                yield delta_text
        else:
            if not delta_text_queue.empty():
                with local_model.dict_lock:
                    remain, delta_text = await delta_text_queue.get()
                yield delta_text
                if remain == 0:
                    break
            else:
                await asyncio.sleep(0)
                continue
    local_model.streamer.pop(request_id, None)


@app.post("/generate")
async def generate(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4())
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            output_str = []
            async for item in generator(local_model, cur_streamer, request_id):
                output_str.append(item)
            return request_id, "".join(output_str)


@app.post("/generate_stream")
async def generate_stream_api(prompt_request: PromptRequest):
    request_id, result = await generate_stream(prompt_request)
    return result


async def generate_stream(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4()) + "stream"
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            if prompt_request.req_type == 'completion':
                cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
            elif prompt_request.req_type == 'chat':
                cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
            else:
                invalidInputError(False, "Invalid Request Type.")

            return request_id, StreamingResponse(
                content=cur_generator, media_type="text/event-stream"
            )


def get_prompt(messages) -> str:
    prompt = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role == "system":
            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
        elif role == "user":
            prompt += f"[INST] {content} [/INST] "
        elif role == "assistant":
            prompt += f"{content} "
        else:
            invalidInputError(False, f"Unknown role: {role}")
    return prompt.strip()


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
    prompt_request = PromptRequest(
        prompt=get_prompt(request.messages),
        n_predict=n_predict,
        req_type="chat"
    )
    if request.stream:
        request_id, result = await generate_stream(prompt_request)
    else:
        request_id, result = await generate(prompt_request)
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=result),
            logprobs=None,
            finish_reason="length")
        result = ChatCompletionResponse(
            id=request_id,
            choices=[choice_data],
            model=model_name)
    return result


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
    prompt_request = PromptRequest(
        prompt=request.prompt,
        n_predict=n_predict,
        req_type="completion"
    )
    if request.stream:
        request_id, result = await generate_stream(prompt_request)
    else:
        request_id, result = await generate(prompt_request)
        choice_data = CompletionResponseChoice(
            index=0,
            text=result,
            logprobs=None,
            finish_reason="length")
        result = CompletionResponse(
            id=request_id,
            choices=[choice_data],
            model=model_name)
    return result


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests(local_model, result_dict))


async def process_requests(local_model, result_dict):
    while True:
        await asyncio.sleep(0)
        await local_model.process_step(tokenizer, result_dict)
 | 
				
			||||||
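For reference, a minimal client sketch against the non-streaming chat endpoint defined above; the host, port, model name, and the exact response field access are illustrative assumptions, not part of this change:

import requests  # third-party HTTP client, assumed to be installed

payload = {
    "model": "llama-2-7b-chat",  # illustrative model name
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    "max_tokens": 64,
    "stream": False,
}
# Assumes the FastAPI app is served locally on port 8000.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])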
							
								
								
									
82 python/llm/src/ipex_llm/serving/fastapi/model_worker.py Normal file

@@ -0,0 +1,82 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from transformers.utils import logging
import time
import asyncio
from transformers import TextIteratorStreamer
logger = logging.get_logger(__name__)


class ModelWorker:
    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
        self.dtype = torch_dtype
        start = time.perf_counter()
        model = self.load_model(checkpoint, low_bit)
        from ipex_llm.utils.benchmark_util import BenchmarkWrapper
        self.model = BenchmarkWrapper(model, do_print=True)
        end = time.perf_counter()
        logger.info(f"Time to load weights: {end - start:.2f}s")
        self.waiting_requests = asyncio.Queue()
        self.streamer = {}
        self.model_name = checkpoint

    def load_model(self, model_path, low_bit='sym_int4'):
        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
        try:
            model = AutoModelForCausalLM.from_pretrained(model_path,
                                                         load_in_low_bit=low_bit,
                                                         torch_dtype=self.dtype,
                                                         optimize_model=True,
                                                         trust_remote_code=True,
                                                         use_cache=True,)
        except:
            model = AutoModel.from_pretrained(model_path,
                                              load_in_low_bit=low_bit,
                                              torch_dtype=self.dtype,
                                              optimize_model=True,
                                              trust_remote_code=True,
                                              use_cache=True,)
        model = model.eval().to("xpu")
        return model

    async def add_request(self, tokenizer):
        if self.waiting_requests.empty():
            return
        tmp_result = await self.waiting_requests.get()
        request_id, prompt_request = tmp_result
        plain_texts = prompt_request.prompt
        inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to('xpu')
        max_tokens = prompt_request.n_predict
        return input_ids, max_tokens, request_id

    @torch.no_grad()
    async def process_step(self, tokenizer, result_dict):
        if not self.waiting_requests.empty():
            input_ids, max_tokens, request_id = await self.add_request(tokenizer)
            self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

            def model_generate():
                self.model.generate(input_ids,
                                    streamer=self.streamer[request_id], max_new_tokens=max_tokens)
                torch.xpu.empty_cache()
                torch.xpu.synchronize()

            from threading import Thread
            t1 = Thread(target=model_generate)
            t1.start()
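A rough sketch of how the serving layer above is expected to drive this worker: prompt requests are queued on waiting_requests and process_step is polled from the event loop. The checkpoint path, tokenizer wiring, and entry point below are assumptions for illustration only:

import asyncio
from transformers import AutoTokenizer
from ipex_llm.serving.fastapi.model_worker import ModelWorker

async def drive_worker():
    checkpoint = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint path
    worker = ModelWorker(checkpoint, low_bit="sym_int4")
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    result_dict = {}
    while True:
        await asyncio.sleep(0)  # yield control so request handlers can enqueue work
        await worker.process_step(tokenizer, result_dict)

# asyncio.run(drive_worker())  # requires an Intel GPU (XPU) runtime with ipex_llm installed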
@@ -15,6 +15,7 @@
#
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+
import time
from typing import Dict, List, Literal, Optional, Union

@@ -22,11 +23,14 @@ import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
+from ipex_llm.utils.common import invalidInputError
+
+
# from vllm.sampling_params import SamplingParams
def random_uuid() -> str:
    return str(uuid.uuid4().hex)


class OpenAIBaseModel(BaseModel):
    # OpenAI API does not allow extra fields
    model_config = ConfigDict(extra="forbid")
@@ -127,10 +131,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
-        description=
-        ("If true, the generation prompt will be added to the chat template. "
-         "This is a parameter used by chat template in tokenizer config of the "
-         "model."),
+        description=(
+            "If true, the generation prompt will be added to the chat template. "
+            "This is a parameter used by chat template in tokenizer config of the "
+            "model."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
@@ -179,9 +183,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
-            raise ValueError(
-                "You can only use one kind of guided decoding "
-                "('guided_json', 'guided_regex' or 'guided_choice').")
+            invalidInputError(False,
+                              "You can only use one kind of guided decoding "
+                              "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


@@ -232,10 +236,10 @@ class CompletionRequest(OpenAIBaseModel):
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
-        description=
-        ("Similar to chat completion, this parameter specifies the format of "
-         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
-         "supported."),
+        description=(
+            "Similar to chat completion, this parameter specifies the format of "
+            "output. Only {'type': 'json_object'} or {'type': 'text' } is "
+            "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
@@ -279,9 +283,9 @@ class CompletionRequest(OpenAIBaseModel):
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
-            raise ValueError(
-                "You can only use one kind of guided decoding "
-                "('guided_json', 'guided_regex' or 'guided_choice').")
+            invalidInputError(False,
+                              "You can only use one kind of guided decoding "
+                              "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


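The two validators above swap raise ValueError for ipex_llm's invalidInputError, which follows an assert-style contract: it raises when its first argument is falsy and is a no-op otherwise. A small standalone sketch of the same check (the guided_* values below are made up for illustration):

from ipex_llm.utils.common import invalidInputError

guided = {"guided_json": '{"type": "object"}', "guided_regex": None, "guided_choice": None}
guide_count = sum(v is not None for v in guided.values())
invalidInputError(guide_count <= 1,
                  "You can only use one kind of guided decoding "
                  "('guided_json', 'guided_regex' or 'guided_choice').")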
@@ -22,4 +22,4 @@ from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \
        AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \
        AutoModelForTokenClassification
from .modelling_bigdl import *
-from .pipeline_parallel import init_pipeline_parallel, ModelRunner
+from .pipeline_parallel import init_pipeline_parallel, PPModelWorker
@@ -468,7 +468,7 @@ def make_attention_mask(prompt_lengths):
    return attention_mask


-class ModelRunner:
+class PPModelWorker:
    """Implementation for pipeline parallel multi-stage serving."""
    def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs,
                 torch_dtype=torch.float16):
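With the rename above, pipeline-parallel serving code imports the worker under its new name; a one-line sketch of the updated import (ModelRunner is no longer exported):

from ipex_llm.transformers import init_pipeline_parallel, PPModelWorker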