LLM: Add /generate_stream endpoint for Pipeline-Parallel-FastAPI example (#11187)

Add /generate_stream and OpenAI-formatted endpoint for Pipeline-Parallel-FastAPI example
Xiangyu Tian 2024-06-14 15:15:32 +08:00 committed by GitHub
parent 9e4d87a696
commit 4359ab3172
11 changed files with 1041 additions and 18 deletions

View file

@@ -18,7 +18,8 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
# configures OneAPI environment variables
source /opt/intel/oneapi/setvars.sh
-pip install mpi4py fastapi uvicorn
+pip install mpi4py fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
pip install transformers==4.31.0 # for llama2 models
@@ -69,3 +70,28 @@ Please change the test url accordingly.
# set t/c to the number of concurrencies to test full throughput.
wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate/ --timeout 1m
```
## 5. Using the `benchmark.py` Script
The `benchmark.py` script benchmarks the streaming service: it measures time to first token, average next-token latency, and token throughput for each tested concurrency level. Below are the details on how to use it:
### Command Line Arguments
- `--prompt_length`: Specifies the length of the prompt used in the test. Acceptable values are `32`, `128`, `1024`, and `2048`.
- `--max_concurrent_requests`: Defines the levels of concurrency for the requests. You can specify multiple values to test different levels of concurrency in one run.
- `--max_new_tokens`: Sets the maximum number of new tokens that the model will generate per request. Default is `128`.
### Usage Example
You can run the script with specific settings for prompt length, concurrent requests, and max new tokens by using the following command:
```bash
python benchmark.py --prompt_length 1024 --max_concurrent_requests 1 2 3 --max_new_tokens 128
```
This command sets the prompt length to 1024, tests concurrency levels of 1, 2, and 3, and configures the model to generate up to 128 new tokens per request. The results are saved in log files named according to the concurrency level (1.log, 2.log, 3.log).
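Under the hood, each benchmark request is a plain streaming POST to the `/generate_stream/` endpoint. A minimal sketch of the same request with `requests` (the prompt text below is only an illustration; it prints the raw streamed chunks without assuming their exact schema):
```python
import requests

# Same payload shape benchmark.py sends: "prompt" plus "n_predict" (max new tokens).
payload = {"prompt": "Once upon a time, there existed a little girl", "n_predict": 32}
with requests.post("http://localhost:8000/generate_stream/",
                   json=payload, stream=True) as response:
    response.raise_for_status()
    # Print streamed chunks as they arrive.
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))
```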
## 6. Gradio Web UI
The `gradio_webui.py` script starts a simple chat UI that talks to the OpenAI-compatible API served on port 8000; pass the model name with `-m`:
```bash
python ./gradio_webui.py -m Llama-2-13b-chat-hf
```
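The OpenAI-compatible routes (`/v1/chat/completions` and `/v1/completions`) can also be called directly. A minimal non-streaming sketch with the `openai` client; the model name is a placeholder for whatever model the server was started with:
```python
from openai import OpenAI

# The API key is not checked by this example server; any string works.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
response = client.chat.completions.create(
    model="Llama-2-13b-chat-hf",  # placeholder model name
    messages=[{"role": "user", "content": "What is AI?"}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```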

View file

@@ -0,0 +1,270 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import requests
import time
from concurrent.futures import ThreadPoolExecutor
import concurrent
import numpy as np
from tqdm import tqdm
import json
import argparse
from typing import List, Tuple
# Execute single request
def perform_request(session, url, payload, headers):
start_time = time.perf_counter()
with session.post(url, json=payload, headers=headers, stream=True) as response:
response.raise_for_status()
first_token_time = None
last_token_time = 0
first_token_inference_time = None
next_token_inference_time = None
next_token_time = []
i = 0
for line in response.iter_lines():
token_time = time.perf_counter() - start_time
if line:
data = line.decode("utf-8").strip()
i = i + 1
try:
json_data = json.loads(data)
if json_data["message"] is not None:
if first_token_time is None:
first_token_time = token_time
else:
next_token_time.append(token_time - last_token_time)
last_token_time = token_time
except json.JSONDecodeError:
pass
end_time = time.perf_counter()
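# Returns: (time to first token, mean inter-token latency, total wall time, and
# two optional server-side inference times that this script currently leaves unset).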
return (
first_token_time,
np.mean(next_token_time),
end_time - start_time,
first_token_inference_time,
next_token_inference_time,
)
def extend_list_to_length(lst, target_length):
if target_length <= len(lst):
return lst[:]
times = target_length // len(lst)
remainder = target_length % len(lst)
extended_list = lst * times + lst[:remainder]
return extended_list
def benchmark(
llm_urls,
prompt,
num_requests,
max_concurrent_requests,
max_tokens,
prompt_length,
is_warmup=False,
):
headers = {"Content-Type": "application/json"}
first_token_latencies = []
next_token_latencies = []
total_response_times = []
first_token_inference_times = []
next_token_inference_times = []
cur_url_index = 0
with requests.Session() as session:
with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor:
llm_url = llm_urls[cur_url_index]
cur_url_index = (cur_url_index + 1) % len(llm_urls)
cur_llm_urls = extend_list_to_length(llm_urls, max_concurrent_requests)
cur_len = len(cur_llm_urls)
payload = {
"prompt": prompt,
"n_predict": max_tokens,
}
futures = [
executor.submit(
perform_request,
session,
cur_llm_urls[index % cur_len],
payload,
headers,
)
for index in range(num_requests)
]
start_time = time.perf_counter()
if is_warmup:
phase = "Warm Up"
else:
phase = "Benchmarking"
with tqdm(total=num_requests, desc=phase, unit="req", ncols=100) as pbar:
for future in concurrent.futures.as_completed(futures):
try:
(
first_token_latency,
next_token_latency,
total_response_time,
first_token_inference_time,
next_token_inference_time,
) = future.result()
first_token_latencies.append(first_token_latency)
next_token_latencies.append(next_token_latency)
total_response_times.append(total_response_time)
if first_token_inference_time:
first_token_inference_times.append(
first_token_inference_time
)
if next_token_inference_time:
next_token_inference_times.append(next_token_inference_time)
except Exception as e:
print(f"Request failed: {e}")
pbar.update(1)
if is_warmup:
return
total_time = time.perf_counter() - start_time
log_file = f"{max_concurrent_requests}.log"
with open(log_file, "w") as file:
print(
f"Total time for {num_requests} requests with {max_concurrent_requests} concurrent requests: {total_time} seconds.",
file=file,
)
print(
f"Average response time: {np.mean(total_responce_times)}", file=file
)
print(
f"Token throughput: {num_requests * max_tokens / total_time}",
file=file,
)
print(
f"Total token throughput: {(max_tokens + prompt_length) * num_requests / total_time}",
file=file,
)
print(file=file)
if first_token_latencies:
average_first_token_latency = sum(first_token_latencies) / len(
first_token_latencies
)
p90_first_token_latency = np.percentile(first_token_latencies, 90)
p95_first_token_latency = np.percentile(first_token_latencies, 95)
# average_first_token_inference_latency = np.mean(
# first_token_inference_times
# )
print(
f"Average first token latency: {average_first_token_latency * 1000} milliseconds.",
file=file,
)
print(
f"P90 first token latency: {p90_first_token_latency * 1000} milliseconds.",
file=file,
)
print(
f"P95 first token latency: {p95_first_token_latency * 1000} milliseconds.",
file=file,
)
# print(
# f"Average first token inference latency: {average_first_token_inference_latency * 1000} milliseconds.",
# file=file,
# )
print(file=file)
if next_token_latencies:
average_next_token_latency = sum(next_token_latencies) / len(
next_token_latencies
)
p90_next_token_latency = np.percentile(next_token_latencies, 90)
p95_next_token_latency = np.percentile(next_token_latencies, 95)
# average_next_token_inference_latency = np.mean(
# next_token_inference_times
# )
print(
f"Average next token latency: {average_next_token_latency * 1000} milliseconds.",
file=file,
)
print(
f"P90 next token latency: {p90_next_token_latency * 1000} milliseconds.",
file=file,
)
print(
f"P95 next token latency: {p95_next_token_latency * 1000} milliseconds.",
file=file,
)
# print(
# f"Average next token inference latency: {average_next_token_inference_latency * 1000} milliseconds.",
# file=file,
# )
print(file=file)
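# Benchmark driver: build the endpoint list, parse CLI arguments, then run a
# warm-up pass followed by the measured pass for each requested concurrency level.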
LLM_URLS = [f"http://localhost:{PORT}/generate_stream/" for PORT in [8000]]
parser = argparse.ArgumentParser(description="Set prompt length.")
parser.add_argument(
"--prompt_length",
type=int,
choices=[32, 128, 1024, 2048],
default=1024,
help="Length of the prompt: 32, 1024, or 2048",
)
parser.add_argument(
"--max_concurrent_requests",
type=int,
nargs="+",
default=[1, 2, 4, 5, 6],
help="List of maximum concurrent requests to test.",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=128,
help="Maximum number of new tokens that the model will generate per request.",
)
args = parser.parse_args()
PROMPT_LENGTH = args.prompt_length
PROMPT = open(f"prompt/{PROMPT_LENGTH}.txt", "r").read()
MAX_TOKENS = args.max_new_tokens
for MAX_CONCURRENT_REQUESTS in args.max_concurrent_requests:
NUM_WARMUP = 5 * MAX_CONCURRENT_REQUESTS
NUM_REQUESTS = 10 * MAX_CONCURRENT_REQUESTS
# warm up
benchmark(
LLM_URLS,
PROMPT,
NUM_WARMUP,
MAX_CONCURRENT_REQUESTS,
MAX_TOKENS,
PROMPT_LENGTH,
is_warmup=True,
)
benchmark(LLM_URLS, PROMPT, NUM_REQUESTS, MAX_CONCURRENT_REQUESTS, MAX_TOKENS, PROMPT_LENGTH)

View file

@@ -0,0 +1,69 @@
import argparse
import gradio as gr
from openai import OpenAI
# Argument parser setup
parser = argparse.ArgumentParser(
description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
type=str,
default='http://localhost:8000/v1',
help='Model URL')
parser.add_argument('-m',
'--model',
type=str,
required=True,
help='Model name for the chatbot')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)
# Parse the arguments
args = parser.parse_args()
# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url
# Create an OpenAI client to interact with the API server
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
def predict(message, history):
# Convert chat history to OpenAI format
history_openai_format = [{
"role": "system",
"content": "You are a great ai assistant."
}]
for human, assistant in history:
history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({
"role": "assistant",
"content": assistant
})
history_openai_format.append({"role": "user", "content": message})
# Create a chat completion request and send it to the API server
stream = client.chat.completions.create(
model=args.model, # Model name to use
messages=history_openai_format, # Chat history
stream=True, # Stream response
)
# Read and return generated text from response stream
partial_message = ""
for chunk in stream:
partial_message += (chunk.choices[0].delta.content or "")
yield partial_message
# Create and launch a chat interface with Gradio
gr.ChatInterface(predict).queue().launch(server_name=args.host,
server_port=args.port,
share=True)

View file

@@ -0,0 +1,367 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
import uuid
from typing import Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
# from vllm.sampling_params import SamplingParams
def random_uuid() -> str:
return str(uuid.uuid4().hex)
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
class ErrorResponse(OpenAIBaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(OpenAIBaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(OpenAIBaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(OpenAIBaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-chat-completion-extra-params
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/completions/create
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
logprobs: Optional[int] = None
max_tokens: Optional[int] = 16
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
suffix: Optional[str] = None
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
description=
("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
"supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
)
guided_regex: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the regex pattern."),
)
guided_choice: Optional[List[str]] = Field(
default=None,
description=(
"If specified, the output will be exactly one of the choices."),
)
guided_grammar: Optional[str] = Field(
default=None,
description=(
"If specified, the output will follow the context free grammar."),
)
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
# doc: end-completion-extra-params
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class LogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: Optional[UsageInfo] = Field(default=None)
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
description=(
"The stop string or token id that caused the completion "
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
class CompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class ChatMessage(OpenAIBaseModel):
role: str
content: str
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: Optional[UsageInfo] = Field(default=None)
class DeltaMessage(OpenAIBaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
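# For reference, a minimal, illustrative sketch of how one streamed chat chunk
# can be assembled from the models above and serialized as a server-sent-events
# line (the id and model name values are placeholders):
if __name__ == "__main__":
    example_choice = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant", content="Hello"),
        logprobs=None,
        finish_reason=None)
    example_chunk = ChatCompletionStreamResponse(
        id="chatcmpl-example",
        model="Llama-2-13b-chat-hf",
        choices=[example_choice])
    print(f"data: {example_chunk.model_dump_json(exclude_unset=True)}\n\n")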

View file

@@ -289,6 +289,12 @@ class ModelRunner:
self.send_buff = None
self.dict_lock = threading.Lock()
self.streamer = {}
self.token_cache = {}
self.print_len = {}
self.is_finish = {}
self.model_name = checkpoint
# def generate(self, input_ids=None, max_tokens=5, attention_mask=None):
# times = []
@@ -422,7 +428,7 @@ class ModelRunner:
if cur_batch is None:
if not self.waiting_requests.empty():
-# await asyncio.sleep(0.01)
+await asyncio.sleep(0.01)
cur_batch = await self.add_request(tokenizer)
cur_input = self.input_ids_dict[cur_batch.batch_id]
else:
@@ -447,6 +453,44 @@ class ModelRunner:
# cur_batch.input_len += 1
cur_batch.input_len = 1
cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
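# Stream newly generated tokens: cache raw token ids per request and only flush
# decoded text up to the last complete word, a trailing newline, or a CJK
# character, so partial sub-word pieces are never emitted; each flush pushes a
# (remaining_tokens, text) pair onto that request's asyncio.Queue.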
for index, request_id in enumerate(cur_batch.request_ids):
if not self.is_finish.get(request_id, False):
remain = cur_batch.max_tokens - len(self.tokens[cur_id])
if self.streamer.get(request_id, None) is None:
self.streamer[request_id] = asyncio.Queue()
if next_ids[index].int() == tokenizer.eos_token_id:
remain = 0
self.is_finish[request_id] = True
if self.token_cache.get(request_id, None) is None:
self.token_cache[request_id] = []
self.print_len[request_id] = 0
self.token_cache[request_id].extend(next_ids[index].tolist())
text = tokenizer.decode(self.token_cache[request_id])
if text.endswith("\n"):
printable_text = text[self.print_len[request_id]:]
self.token_cache[request_id] = []
self.print_len[request_id] = 0
elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len[request_id]:]
self.print_len[request_id] += len(printable_text)
else:
printable_text = text[self.print_len[request_id] : text.rfind(" ") + 1]
self.print_len[request_id] += len(printable_text)
if remain > 0:
await self.streamer[request_id].put((remain, printable_text))
else:
printable_text = printable_text + text[self.print_len[request_id]:]
self.token_cache.pop(request_id, None)
self.print_len.pop(request_id, None)
await self.streamer[request_id].put((remain, printable_text))
if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
# Finish a batch
# logger.info(self.tokens[cur_id])
@@ -509,3 +553,27 @@ class ModelRunner:
self.on_going_batches[:-1] = self.on_going_batches[1:]
self.on_going_batches[self.world_size - 1] = cur_batch
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False

View file

@@ -3,7 +3,10 @@ import torch.nn.parallel
import torch.distributed as dist
import os
import ipex_llm
from ipex_llm.utils.common import invalidInputError
import oneccl_bindings_for_pytorch
import json
from transformers.utils import logging
logger = logging.get_logger(__name__)
@@ -20,11 +23,12 @@ logger.info(f"rank: {my_rank}, size: {my_size}")
import time
from transformers import AutoTokenizer, AutoConfig, LlamaTokenizer
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
import asyncio, uuid
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any, Callable, Union
import argparse
def get_int_from_env(env_keys, default):
@@ -38,8 +42,22 @@ def get_int_from_env(env_keys, default):
class PromptRequest(BaseModel):
prompt: str
-n_predict: int = 32
+n_predict: Optional[int] = 256
req_type: str = 'completion'
from openai.types.chat import ChatCompletionMessageParam
class ChatCompletionRequest(BaseModel):
messages: List[ChatCompletionMessageParam]
model: str
max_tokens: Optional[int] = None
stream: Optional[bool] = False
class CompletionRequest(BaseModel):
model: str
prompt: Union[List[int], List[List[int]], str, List[str]]
max_tokens: Optional[int] = None
stream: Optional[bool] = False
empty_req = PromptRequest(prompt="", n_predict=0)
@@ -49,8 +67,112 @@ global local_model
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
streamer_dict = {}
local_rank = my_rank
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")
from openai_protocol import (
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatCompletionResponseChoice,
ChatCompletionResponse,
ChatMessage,
DeltaMessage,
CompletionResponseChoice,
CompletionResponse,
CompletionResponseStreamChoice,
CompletionStreamResponse,
)
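# The two async generators below drain a request's streamer queue and wrap each
# decoded text fragment in an OpenAI-style streaming chunk, emitted as a
# server-sent-events "data: ..." line; remain == 0 marks the final chunk.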
async def chat_stream_generator(local_model, delta_text_queue, request_id):
model_name = local_model.model_name
index = 0
while True:
if not delta_text_queue.empty():
with local_model.dict_lock:
remain, delta_text = await delta_text_queue.get()
# print(remain)
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role="assistant", content=delta_text),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
index = index + 1
if remain == 0:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(role="assistant", content=None),
logprobs=None,
finish_reason="length")
chunk = ChatCompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
break
else:
await asyncio.sleep(0)
local_model.streamer.pop(request_id, None)
async def completion_stream_generator(local_model, delta_text_queue, request_id):
model_name = local_model.model_name
index = 0
while True:
if not delta_text_queue.empty():
with local_model.dict_lock:
remain, delta_text = await delta_text_queue.get()
# print(remain)
choice_data = CompletionResponseStreamChoice(
index=index,
text=delta_text,
logprobs=None,
finish_reason=None)
chunk = CompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
index = index + 1
if remain == 0:
choice_data = CompletionResponseStreamChoice(
index=index,
text="",
logprobs=None,
finish_reason="length")
chunk = CompletionStreamResponse(
id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
break
else:
await asyncio.sleep(0)
local_model.streamer.pop(request_id, None)
async def generator(local_model, delta_text_queue, request_id):
while True:
if not delta_text_queue.empty():
with local_model.dict_lock:
remain, delta_text = await delta_text_queue.get()
yield delta_text
if remain == 0:
break
else:
await asyncio.sleep(0)
# streamer_dict.pop(request_id, None)
local_model.streamer.pop(request_id, None)
@app.post("/generate/") @app.post("/generate/")
@ -58,16 +180,106 @@ async def generate(prompt_request: PromptRequest):
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
await local_model.waiting_requests.put((request_id, prompt_request)) await local_model.waiting_requests.put((request_id, prompt_request))
while True: while True:
if request_id in result_dict: await asyncio.sleep(0)
with local_model.dict_lock: cur_streamer = local_model.streamer.get(request_id, None)
output_str = result_dict[request_id] if cur_streamer is not None:
if len(output_str) == 0: output_str = []
logger.info(f"Why? {request_id}") async for item in generator(local_model, cur_streamer, request_id):
# await asyncio.sleep(0.1) output_str.append(item)
# continue return request_id, "".join(output_str)
result_dict.pop(request_id)
return {"generated_text": output_str}
await asyncio.sleep(0) @app.post("/generate_stream/")
async def generate_stream(prompt_request: PromptRequest):
request_id = str(uuid.uuid4()) + "stream"
await local_model.waiting_requests.put((request_id, prompt_request))
while True:
await asyncio.sleep(0)
cur_streamer = local_model.streamer.get(request_id, None)
if cur_streamer is not None:
if prompt_request.req_type == 'completion':
cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
elif prompt_request.req_type == 'chat':
cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
else:
invalidInputError(False, "Invalid Request Type.")
return request_id, StreamingResponse(
content=cur_generator, media_type="text/event-stream"
)
DEFAULT_SYSTEM_PROMPT = """\
"""
def get_prompt(messages) -> str:
prompt = ""
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
prompt += f"<<SYS>>\n{content}\n<</SYS>>\n\n"
elif role == "user":
prompt += f"[INST] {content} [/INST] "
elif role == "assistant":
prompt += f"{content} "
else:
raise ValueError(f"Unknown role: {role}")
return prompt.strip()
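# e.g. get_prompt([{"role": "user", "content": "What is AI?"}])
#   returns "[INST] What is AI? [/INST]"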
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
model_name = local_model.model_name
if request.max_tokens is None:
n_predict = 256
else:
n_predict = request.max_tokens
prompt_request = PromptRequest(
prompt=get_prompt(request.messages),
n_predict=n_predict,
req_type="chat"
)
if request.stream:
request_id, result = await generate_stream(prompt_request)
# For streaming requests, return the SSE StreamingResponse directly.
return result
else:
request_id, result = await generate(prompt_request)
choice_data = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=result),
logprobs=None,
finish_reason="length")
result = ChatCompletionResponse(
id=request_id,
choices=[choice_data],
model=model_name)
return result
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
model_name = local_model.model_name
if request.max_tokens is None:
n_predict = 256
else:
n_predict = request.max_tokens
prompt_request = PromptRequest(
prompt=request.prompt,
n_predict=n_predict,
req_type="completion"
)
if request.stream:
request_id, result = await generate_stream(prompt_request)
# For streaming requests, return the SSE StreamingResponse directly.
return result
else:
request_id, result = await generate(prompt_request)
choice_data = CompletionResponseChoice(
index=0,
text=result,
logprobs=None,
finish_reason="length")
result = CompletionResponse(
id=request_id,
choices=[choice_data],
model=model_name)
return result
def generate_text(prompt: List[str], n_predict = 32):

View file

@@ -0,0 +1 @@
Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. However, her parents were always telling her to stay close to home, to be careful, and to avoid any danger. But the little girl was stubborn, and she wanted to see what was on the other side of the mountain. So she sneaked out of the house one night, leaving a note for her parents, and set off on her journey. As she climbed the mountain, the little girl felt a sense of excitement and wonder. She had never been this far away from home before, and she couldnt wait to see what she would find on the other side. She climbed higher and higher, her lungs burning from the thin air, until she finally reached the top of the mountain. And there, she found a beautiful meadow filled with wildflowers and a sparkling stream. The little girl danced and played in the meadow, feeling free and alive. She knew she had to return home eventually, but for now, she was content to enjoy her adventure. As the sun began to set, the little girl reluctantly made her way back down the mountain, but she knew that she would never forget her adventure and the joy of discovering something new and exciting. And whenever she felt scared or unsure, she would remember the thrill of climbing the mountain and the beauty of the meadow on the other side, and she would know that she could face any challenge that came her way, with courage and determination. She carried the memories of her journey in her heart, a constant reminder of the strength she possessed. The little girl returned home to her worried parents, who had discovered her note and anxiously awaited her arrival. They scolded her for disobeying their instructions and venturing into the unknown. But as they looked into her sparkling eyes and saw the glow on her face, their anger softened. They realized that their little girl had grown, that she had experienced something extraordinary. The little girl shared her tales of the mountain and the meadow with her parents, painting vivid pictures with her words. She spoke of the breathtaking view from the mountaintop, where the world seemed to stretch endlessly before her. She described the delicate petals of the wildflowers, vibrant hues that danced in the gentle breeze. And she recounted the soothing melody of the sparkling stream, its waters reflecting the golden rays of the setting sun. Her parents listened intently, captivated by her story. They realized that their daughter had discovered a part of herself on that journey—a spirit of curiosity and a thirst for exploration. They saw that she had learned valuable lessons about independence, resilience, and the beauty that lies beyond ones comfort zone. From that day forward, the little girls parents encouraged her to pursue her dreams and embrace new experiences. They understood that while there were risks in the world, there were also rewards waiting to be discovered. They supported her as she continued to embark on adventures, always reminding her to stay safe but never stifling her spirit. As the years passed, the little girl grew into a remarkable woman, fearlessly exploring the world and making a difference wherever she went. The lessons she had learned on that fateful journey stayed with her, guiding her through challenges and inspiring her to live life to the fullest. 
And so, the once timid little girl became a symbol of courage and resilience, a reminder to all who knew her that the greatest joys in life often lie just beyond the mountains we fear to climb. Her story spread far and wide, inspiring others to embrace their own journeys and discover the wonders that awaited them. In the end, the little girls adventure became a timeless tale, passed down through generations, reminding us all that sometimes, the greatest rewards come to those who dare to step into the unknown and follow their hearts. With each passing day, the little girls story continued to inspire countless individuals, igniting a spark within their souls and encouraging them to embark on their own extraordinary adventures. The tale of her bravery and determination resonated deeply with people from all walks of life, reminding them of the limitless possibilities that awaited them beyond the boundaries of their comfort zones. People marveled at the little girls unwavering spirit and her unwavering belief in the power of dreams. They saw themselves reflected in her journey, finding solace in the knowledge that they too could overcome their fears and pursue their passions. The little girl's story became a beacon of hope, a testament to the human spirit

View file

@@ -0,0 +1 @@
In a distant future, humanity has expanded across the galaxy, establishing colonies on numerous planets. The interstellar community thrives under the guidance of the United Galactic Federation, which ensures peace and prosperity. However, a new threat emerges from the unknown regions of space, challenging the stability and security of the galaxy. Brave explorers and seasoned warriors must unite to uncover the secrets of this mysterious force and protect the future of all sentient beings. Please continue the above story as long as possible, preferably more than 1000 tokens.

File diff suppressed because one or more lines are too long

View file

@@ -0,0 +1 @@
Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun

View file

@@ -1,7 +1,12 @@
source /opt/intel/oneapi/setvars.sh
export no_proxy=localhost
export FI_PROVIDER=tcp
-export OMP_NUM_THREADS=8
+export OMP_NUM_THREADS=32
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
basekit_root=/opt/intel/oneapi
source $basekit_root/setvars.sh --force
source $basekit_root/ccl/latest/env/vars.sh --force
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
@@ -9,4 +14,6 @@ export TORCH_LLM_ALLREDUCE=0
export MODEL_PATH=YOUR_MODEL_PATH
export NUM_GPUS=2
-CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8
export BIGDL_QUANTIZE_KV_CACHE=1
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4