Refactor fastapi-serving and add one-card serving (#11581)

* init fastapi-serving one card

* mv api code to source

* update worker

* update for style-check

* add worker

* update bash

* update

* update worker name and add readme

* rename update

* rename to fastapi
Wang, Jian4 2024-07-17 11:12:43 +08:00 committed by GitHub
parent 373ccbbb0c
commit 9c15abf825
19 changed files with 583 additions and 367 deletions

View file

@@ -61,7 +61,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \
# Download pp_serving
mkdir -p /llm/pp_serving && \
cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-FastAPI/*.py /llm/pp_serving/ && \
cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-Serving/*.py /llm/pp_serving/ && \
# Install related library of benchmarking
pip install pandas omegaconf && \
chmod +x /llm/benchmark.sh && \

View file

@@ -1,346 +0,0 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch.nn.parallel
import torch.distributed as dist
import os
from ipex_llm.utils.common import invalidInputError
from ipex_llm.transformers import init_pipeline_parallel, ModelRunner
import oneccl_bindings_for_pytorch
import json
from transformers.utils import logging
logger = logging.get_logger(__name__)
init_pipeline_parallel()
my_rank = dist.get_rank()
my_size = dist.get_world_size()
device = f"xpu:{my_rank}"
logger.info(f"rank: {my_rank}, size: {my_size}")
import time
from transformers import AutoTokenizer
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
import asyncio, uuid
from typing import Dict, List, Optional, Any, Callable, Union
import argparse
class PromptRequest(BaseModel):
prompt: str
n_predict: Optional[int] = 256
req_type: str = 'completion'
from openai.types.chat import ChatCompletionMessageParam
class ChatCompletionRequest(BaseModel):
messages: List[ChatCompletionMessageParam]
model: str
max_tokens: Optional[int] = None
stream: Optional[bool] = False
class CompletionRequest(BaseModel):
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
max_tokens: Optional[int] = None
stream: Optional[bool] = False
empty_req = PromptRequest(prompt="", n_predict=0)
app = FastAPI()
global tokenizer
global local_model
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
streamer_dict = {}
local_rank = my_rank
from openai_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionResponseChoice,
ChatCompletionResponse,
ChatMessage,
DeltaMessage,
CompletionResponseChoice,
CompletionResponse,
CompletionResponseStreamChoice,
CompletionStreamResponse,
)
async def chat_stream_generator(local_model, delta_text_queue, request_id):
model_name = local_model.model_name
index = 0
while True:
if not delta_text_queue.empty():
with local_model.dict_lock:
remain, delta_text = await delta_text_queue.get()
# print(remain)
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role="assistant", content=delta_text),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
index = index + 1
if remain == 0:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role="assistant", content=None),
logprobs=None,
finish_reason="length")
chunk = ChatCompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
break
else:
await asyncio.sleep(0)
local_model.streamer.pop(request_id, None)
async def completion_stream_generator(local_model, delta_text_queue, request_id):
model_name = local_model.model_name
index = 0
while True:
if not delta_text_queue.empty():
with local_model.dict_lock:
remain, delta_text = await delta_text_queue.get()
# print(remain)
choice_data = CompletionResponseStreamChoice(
index=index,
text=delta_text,
logprobs=None,
finish_reason=None)
chunk = CompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
index = index + 1
if remain == 0:
choice_data = CompletionResponseStreamChoice(
index=index,
text="",
logprobs=None,
finish_reason="length")
chunk = CompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
break
else:
await asyncio.sleep(0)
local_model.streamer.pop(request_id, None)
async def generator(local_model, delta_text_queue, request_id):
while True:
if not delta_text_queue.empty():
with local_model.dict_lock:
remain, delta_text = await delta_text_queue.get()
yield delta_text
if remain == 0:
break
else:
await asyncio.sleep(0)
local_model.streamer.pop(request_id, None)
@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
request_id = str(uuid.uuid4())
await local_model.waiting_requests.put((request_id, prompt_request))
while True:
await asyncio.sleep(0)
cur_streamer = local_model.streamer.get(request_id, None)
if cur_streamer is not None:
output_str = []
async for item in generator(local_model, cur_streamer, request_id):
output_str.append(item)
return request_id, "".join(output_str)
async def generate_stream(prompt_request: PromptRequest):
request_id = str(uuid.uuid4()) + "stream"
await local_model.waiting_requests.put((request_id, prompt_request))
while True:
await asyncio.sleep(0)
cur_streamer = local_model.streamer.get(request_id, None)
if cur_streamer is not None:
if prompt_request.req_type == 'completion':
cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
elif prompt_request.req_type == 'chat':
cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
else:
invalidInputError(False, "Invalid Request Type.")
return request_id, StreamingResponse(
content=cur_generator, media_type="text/event-stream"
)
@app.post("/generate_stream/")
async def generate_stream_api(prompt_request: PromptRequest):
request_id, result = await generate_stream(prompt_request)
return result
DEFAULT_SYSTEM_PROMPT = """\
"""
def get_prompt(messages) -> str:
prompt = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
elif role == "user":
prompt += f"[INST] {content} [/INST] "
elif role == "assistant":
prompt += f"{content} "
else:
raise ValueError(f"Unknown role: {role}")
return prompt.strip()
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
model_name = local_model.model_name
if request.max_tokens is None:
n_predict = 256
else:
n_predict = request.max_tokens
prompt_request = PromptRequest(
prompt=get_prompt(request.messages),
n_predict=n_predict,
req_type="chat"
)
if request.stream:
request_id, result = await generate_stream(prompt_request)
else:
request_id, result = await generate(prompt_request)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=result),
logprobs=None,
finish_reason="length")
result = ChatCompletionResponse(
id=request_id,
choices=[choice_data],
model=model_name)
return result
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
model_name = local_model.model_name
if request.max_tokens is None:
n_predict = 256
else:
n_predict = request.max_tokens
prompt_request = PromptRequest(
prompt=request.prompt,
n_predict=n_predict,
req_type="completion"
)
if request.stream:
request_id, result = await generate_stream(prompt_request)
else:
request_id, result = await generate(prompt_request)
choice_data = CompletionResponseChoice(
index=0,
text=result,
logprobs=None,
finish_reason="length")
result = CompletionResponse(
id=request_id,
choices=[choice_data],
model=model_name)
return result
async def process_requests(local_model, result_dict):
while True:
await asyncio.sleep(0)
await local_model.process_step(tokenizer, result_dict)
@app.on_event("startup")
async def startup_event():
asyncio.create_task(process_requests(local_model, result_dict))
async def main():
parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
', or the path to the huggingface checkpoint folder')
parser.add_argument('--low-bit', type=str, default='sym_int4',
help='The quantization type the model will convert to.')
parser.add_argument('--port', type=int, default=8000,
help='The port number on which the server will run.')
parser.add_argument('--max-num-seqs', type=int, default=8,
help='Max num sequences in a batch.')
parser.add_argument('--max-prefilled-seqs', type=int, default=0,
help='Max num sequences in a batch during prefilling.')
args = parser.parse_args()
model_path = args.repo_id_or_model_path
low_bit = args.low_bit
max_num_seqs = args.max_num_seqs
max_prefilled_seqs = args.max_prefilled_seqs
# serialize model initialization so that we do not run out of CPU memory
for i in range(my_size):
if my_rank == i:
logger.info("start model initialization")
global local_model
local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
logger.info("model initialized")
dist.barrier()
# Load tokenizer
global tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if local_rank == 0:
config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
server = uvicorn.Server(config)
await server.serve()
else:
while True:
await asyncio.sleep(0)
await local_model.process_step(tokenizer, result_dict)
if __name__ == "__main__":
asyncio.run(main())

View file

@@ -50,7 +50,14 @@ pip install transformers==4.40.0
pip install trl==0.8.1
```
### 2. Run pipeline parallel serving on multiple GPUs
### 2-1. Run ipex-llm serving on one GPU card
```bash
# Need to set NUM_GPUS=1 and MODEL_PATH in run.sh first
bash run.sh
```
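For reference, a minimal sketch of the `run.sh` edits this mode assumes (the model path below is a placeholder; point it at your local checkpoint):
```bash
# Illustrative run.sh settings for single-card serving (placeholder path)
export MODEL_PATH="/llm/models/Llama-2-7b-chat-hf"
export NUM_GPUS=1
```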
### 2-2. Run pipeline parallel serving on multiple GPUs
```bash
# Need to set MODEL_PATH in run.sh first
@@ -76,7 +83,7 @@ export http_proxy=
export https_proxy=
curl -X 'POST' \
'http://127.0.0.1:8000/generate/' \
'http://127.0.0.1:8000/generate' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
@@ -99,7 +106,7 @@ Please change the test url accordingly.
```bash
# set t/c to the number of concurrencies to test full throughput.
wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate/ --timeout 1m
wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate --timeout 1m
```
## 5. Using the `benchmark.py` Script

View file

@@ -0,0 +1,78 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch.distributed as dist
from ipex_llm.transformers import init_pipeline_parallel, PPModelWorker
from ipex_llm.serving.fastapi import FastApp
from transformers.utils import logging
from transformers import AutoTokenizer
import uvicorn
import asyncio
from typing import Dict
import argparse

logger = logging.get_logger(__name__)

init_pipeline_parallel()
my_rank = dist.get_rank()
my_size = dist.get_world_size()
device = f"xpu:{my_rank}"
logger.info(f"rank: {my_rank}, size: {my_size}")

result_dict: Dict[str, str] = {}
local_rank = my_rank


async def main():
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging Pipeline-Parallel')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')
    parser.add_argument('--max-num-seqs', type=int, default=8,
                        help='Max num sequences in a batch.')
    parser.add_argument('--max-prefilled-seqs', type=int, default=0,
                        help='Max num sequences in a batch during prefilling.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    max_num_seqs = args.max_num_seqs
    max_prefilled_seqs = args.max_prefilled_seqs

    # serialize model initialization so that we do not run out of CPU memory
    for i in range(my_size):
        if my_rank == i:
            logger.info("start model initialization")
            local_model = PPModelWorker(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
            logger.info("model initialized")
        dist.barrier()

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    myapp = FastApp(local_model, tokenizer)
    if local_rank == 0:
        config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
        server = uvicorn.Server(config)
        await server.serve()
    else:
        while True:
            await asyncio.sleep(0)
            await local_model.process_step(tokenizer, result_dict)

if __name__ == "__main__":
    asyncio.run(main())
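For context, a sketch of how a multi-card launch of this script typically looks, mirroring the `torchrun` invocation that the `run.sh` change later in this commit wires up; the model path and GPU count are placeholders:
```bash
# Illustrative multi-card launch; flags correspond to the argparse options defined above
export MODEL_PATH="/llm/models/Llama-2-7b-chat-hf"   # placeholder
export NUM_GPUS=2                                    # placeholder
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
  pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit sym_int4 \
  --max-num-seqs 8 --max-prefilled-seqs 0
```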

View file

@@ -40,4 +40,9 @@ export LOW_BIT="fp8"
export MAX_NUM_SEQS="4"
export MAX_PREFILLED_SEQS=0
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT --max-num-seqs $MAX_NUM_SEQS --max-prefilled-seqs $MAX_PREFILLED_SEQS
if [[ $NUM_GPUS -eq 1 ]]; then
  export ZE_AFFINITY_MASK=0
  python serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT
else
  CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT --max-num-seqs $MAX_NUM_SEQS --max-prefilled-seqs $MAX_PREFILLED_SEQS
fi

View file

@@ -0,0 +1,53 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from transformers.utils import logging
import time
from transformers import AutoTokenizer
import uvicorn
import asyncio
import argparse
from ipex_llm.serving.fastapi import FastApp
from ipex_llm.serving.fastapi import ModelWorker

logger = logging.get_logger(__name__)


async def main():
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging ipex-llm')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')
    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit

    local_model = ModelWorker(model_path, low_bit)
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    myapp = FastApp(local_model, tokenizer)
    config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
    server = uvicorn.Server(config)
    await server.serve()

if __name__ == "__main__":
    asyncio.run(main())
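As a quick sanity check, a sketch of launching this single-card server and querying the `/generate` route registered by `FastApp`; the model path and prompt are placeholders, and `ZE_AFFINITY_MASK` is set the same way `run.sh` does in its single-GPU branch:
```bash
# Illustrative single-card launch (placeholder model path)
export ZE_AFFINITY_MASK=0
python serving.py --repo-id-or-model-path /llm/models/Llama-2-7b-chat-hf --low-bit sym_int4 --port 8000

# Illustrative request; the JSON fields follow the PromptRequest schema (prompt, n_predict)
curl -X POST 'http://127.0.0.1:8000/generate' \
  -H 'Content-Type: application/json' \
  -d '{"prompt": "What is AI?", "n_predict": 32}'
```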

View file

@@ -0,0 +1,18 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .api_server import FastApp
from .model_worker import ModelWorker

View file

@@ -0,0 +1,315 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from ipex_llm.utils.common import invalidInputError
from transformers.utils import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ipex_llm.utils.common import invalidInputError
import asyncio
import uuid
from typing import List, Optional, Union, Dict

result_dict: Dict[str, str] = {}
logger = logging.get_logger(__name__)


class PromptRequest(BaseModel):
    prompt: str
    n_predict: Optional[int] = 256
    req_type: str = 'completion'


class ChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionMessageParam]
    model: str
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


class CompletionRequest(BaseModel):
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False


app = FastAPI()
global tokenizer
global local_model


class FastApp():
    def __init__(self, model, mytokenizer):
        global tokenizer
        global local_model
        local_model = model
        tokenizer = mytokenizer
        self.app = app


from .openai_protocol import (
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponse,
    ChatMessage,
    DeltaMessage,
    CompletionResponseChoice,
    CompletionResponse,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
)


def get_queue_next_token(delta_text_queue):
    timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
    delta_text = delta_text_queue.text_queue.get(timeout=timeout)
    if delta_text is None:
        remain = 0
    else:
        remain = 1
    return delta_text, remain


async def chat_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
    while True:
        if not hasattr(delta_text_queue, 'empty'):
            delta_text, remain = get_queue_next_token(delta_text_queue)
        else:
            if not delta_text_queue.empty():
                with local_model.dict_lock:
                    remain, delta_text = await delta_text_queue.get()
            else:
                await asyncio.sleep(0)
                continue
        if remain == 0 and delta_text is not None or remain != 0:
            choice_data = ChatCompletionResponseStreamChoice(
                index=index,
                delta=DeltaMessage(role="assistant", content=delta_text),
                logprobs=None,
                finish_reason=None)
            chunk = ChatCompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            index = index + 1
        if remain == 0:
            choice_data = ChatCompletionResponseStreamChoice(
                index=index,
                delta=DeltaMessage(role="assistant", content=None),
                logprobs=None,
                finish_reason="length")
            chunk = ChatCompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            break
    local_model.streamer.pop(request_id, None)


async def completion_stream_generator(local_model, delta_text_queue, request_id):
    model_name = local_model.model_name
    index = 0
    while True:
        if not hasattr(delta_text_queue, 'empty'):
            delta_text, remain = get_queue_next_token(delta_text_queue)
        else:
            if not delta_text_queue.empty():
                with local_model.dict_lock:
                    remain, delta_text = await delta_text_queue.get()
            else:
                await asyncio.sleep(0)
                continue
        if remain == 0 and delta_text is not None or remain != 0:
            choice_data = CompletionResponseStreamChoice(
                index=index,
                text=delta_text,
                logprobs=None,
                finish_reason=None)
            chunk = CompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            index = index + 1
        if remain == 0:
            choice_data = CompletionResponseStreamChoice(
                index=index,
                text="",
                logprobs=None,
                finish_reason="length")
            chunk = CompletionStreamResponse(
                id=request_id,
                choices=[choice_data],
                model=model_name)
            data = chunk.model_dump_json(exclude_unset=True)
            yield f"data: {data}\n\n"
            break
    local_model.streamer.pop(request_id, None)


async def generator(local_model, delta_text_queue, request_id):
    while True:
        if not hasattr(delta_text_queue, 'empty'):
            delta_text, remain = get_queue_next_token(delta_text_queue)
            if delta_text is None:
                break
            else:
                yield delta_text
        else:
            if not delta_text_queue.empty():
                with local_model.dict_lock:
                    remain, delta_text = await delta_text_queue.get()
                yield delta_text
                if remain == 0:
                    break
            else:
                await asyncio.sleep(0)
                continue
    local_model.streamer.pop(request_id, None)


@app.post("/generate")
async def generate(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4())
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            output_str = []
            async for item in generator(local_model, cur_streamer, request_id):
                output_str.append(item)
            return request_id, "".join(output_str)


@app.post("/generate_stream")
async def generate_stream_api(prompt_request: PromptRequest):
    request_id, result = await generate_stream(prompt_request)
    return result


async def generate_stream(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4()) + "stream"
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        await asyncio.sleep(0)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            if prompt_request.req_type == 'completion':
                cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
            elif prompt_request.req_type == 'chat':
                cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
            else:
                invalidInputError(False, "Invalid Request Type.")
            return request_id, StreamingResponse(
                content=cur_generator, media_type="text/event-stream"
            )


def get_prompt(messages) -> str:
    prompt = ""
    for msg in messages:
        role = msg["role"]
        content = msg["content"]
        if role == "system":
            prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
        elif role == "user":
            prompt += f"[INST] {content} [/INST] "
        elif role == "assistant":
            prompt += f"{content} "
        else:
            invalidInputError(False, f"Unknown role: {role}")
    return prompt.strip()


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
    prompt_request = PromptRequest(
        prompt=get_prompt(request.messages),
        n_predict=n_predict,
        req_type="chat"
    )
    if request.stream:
        request_id, result = await generate_stream(prompt_request)
    else:
        request_id, result = await generate(prompt_request)
        choice_data = ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content=result),
            logprobs=None,
            finish_reason="length")
        result = ChatCompletionResponse(
            id=request_id,
            choices=[choice_data],
            model=model_name)
    return result


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    model_name = local_model.model_name
    if request.max_tokens is None:
        n_predict = 256
    else:
        n_predict = request.max_tokens
    prompt_request = PromptRequest(
        prompt=request.prompt,
        n_predict=n_predict,
        req_type="completion"
    )
    if request.stream:
        request_id, result = await generate_stream(prompt_request)
    else:
        request_id, result = await generate(prompt_request)
        choice_data = CompletionResponseChoice(
            index=0,
            text=result,
            logprobs=None,
            finish_reason="length")
        result = CompletionResponse(
            id=request_id,
            choices=[choice_data],
            model=model_name)
    return result


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests(local_model, result_dict))


async def process_requests(local_model, result_dict):
    while True:
        await asyncio.sleep(0)
        await local_model.process_step(tokenizer, result_dict)
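For reference, a sketch of exercising the OpenAI-compatible chat route defined above once a server built on `FastApp` (for example `serving.py`) is running; the model name, prompt, and token budget are placeholders:
```bash
# Illustrative request to /v1/chat/completions; fields follow the ChatCompletionRequest schema
curl -X POST 'http://127.0.0.1:8000/v1/chat/completions' \
  -H 'Content-Type: application/json' \
  -d '{
        "model": "Llama-2-7b-chat-hf",
        "messages": [{"role": "user", "content": "What is AI?"}],
        "max_tokens": 128,
        "stream": false
      }'
```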

View file

@@ -0,0 +1,82 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from transformers.utils import logging
import time
import asyncio
from transformers import TextIteratorStreamer

logger = logging.get_logger(__name__)


class ModelWorker:
    def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
        self.dtype = torch_dtype
        start = time.perf_counter()
        model = self.load_model(checkpoint, low_bit)
        from ipex_llm.utils.benchmark_util import BenchmarkWrapper
        self.model = BenchmarkWrapper(model, do_print=True)
        end = time.perf_counter()
        logger.info(f"Time to load weights: {end - start:.2f}s")
        self.waiting_requests = asyncio.Queue()
        self.streamer = {}
        self.model_name = checkpoint

    def load_model(self, model_path, low_bit='sym_int4'):
        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
        try:
            model = AutoModelForCausalLM.from_pretrained(model_path,
                                                         load_in_low_bit=low_bit,
                                                         torch_dtype=self.dtype,
                                                         optimize_model=True,
                                                         trust_remote_code=True,
                                                         use_cache=True,)
        except:
            model = AutoModel.from_pretrained(model_path,
                                              load_in_low_bit=low_bit,
                                              torch_dtype=self.dtype,
                                              optimize_model=True,
                                              trust_remote_code=True,
                                              use_cache=True,)
        model = model.eval().to("xpu")
        return model

    async def add_request(self, tokenizer):
        if self.waiting_requests.empty():
            return
        tmp_result = await self.waiting_requests.get()
        request_id, prompt_request = tmp_result
        plain_texts = prompt_request.prompt
        inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to('xpu')
        max_tokens = prompt_request.n_predict
        return input_ids, max_tokens, request_id

    @torch.no_grad()
    async def process_step(self, tokenizer, result_dict):
        if not self.waiting_requests.empty():
            input_ids, max_tokens, request_id = await self.add_request(tokenizer)
            self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)

            def model_generate():
                self.model.generate(input_ids,
                                    streamer=self.streamer[request_id], max_new_tokens=max_tokens)
                torch.xpu.empty_cache()
                torch.xpu.synchronize()

            from threading import Thread
            t1 = Thread(target=model_generate)
            t1.start()

View file

@@ -15,6 +15,7 @@
#
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
@@ -22,11 +23,14 @@ import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
from ipex_llm.utils.common import invalidInputError
# from vllm.sampling_params import SamplingParams
def random_uuid() -> str:
return str(uuid.uuid4().hex)
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
@@ -127,8 +131,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
)
add_generation_prompt: Optional[bool] = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
@@ -179,7 +183,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
invalidInputError(False,
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@@ -232,8 +236,8 @@ class CompletionRequest(OpenAIBaseModel):
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
description=(
"Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
@@ -279,7 +283,7 @@ class CompletionRequest(OpenAIBaseModel):
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
invalidInputError(False,
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data

View file

@@ -22,4 +22,4 @@ from .model import AutoModelForCausalLM, AutoModel, AutoModelForSeq2SeqLM, \
AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \
AutoModelForTokenClassification
from .modelling_bigdl import *
from .pipeline_parallel import init_pipeline_parallel, ModelRunner
from .pipeline_parallel import init_pipeline_parallel, PPModelWorker

View file

@@ -468,7 +468,7 @@ def make_attention_mask(prompt_lengths):
return attention_mask
class ModelRunner:
class PPModelWorker:
"""Implementation for pipeline parallel multi-stage serving."""
def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs,
torch_dtype=torch.float16):