Integrate vllm (#9310)

* done

* Rename structure

* add models

* Add structure/sampling_params,sequence

* add input_metadata

* add outputs

* Add policy,logger

* add and update

* add parallelconfig back

* core/scheduler.py

* Add llm_engine.py

* Add async_llm_engine.py

* Add tested entrypoint

* fix minor error

* Fix everything

* fix kv cache view

* fix

* fix

* fix

* format&refine

* remove logger from repo

* try to add token latency

* remove logger

* Refine config.py

* finish worker.py

* delete utils.py

* add license

* refine

* refine sequence.py

* remove sampling_params.py

* finish

* add license

* format

* add license

* refine

* refine

* Refine line too long

* remove exception

* so dumb style-check

* refine

* refine

* refine

* refine

* refine

* refine

* add README

* refine README

* add warning instead of error

* fix padding

* add license

* format

* format

* format fix

* Refine vllm dependency (#1)

vllm dependency clear

* fix licence

* fix format

* fix format

* fix

* adapt LLM engine

* fix

* add license

* fix format

* fix

* Moving README.md to the correct position

* Fix readme.md

* done

* guide for adding models

* fix

* Fix README.md

* Add new model readme

* remove ray-logic

* refactor arg_utils.py

* remove distributed_init_method logic

* refactor entrypoints

* refactor input_metadata

* refactor model_loader

* refactor utils.py

* refactor models

* fix api server

* remove vllm.structure

* revert by txy 1120

* remove utils

* format

* fix license

* add bigdl model

* Refer to a specific commit

* Change code base

* add comments

* add async_llm_engine comment

* refine

* formatted

* add worker comments

* add comments

* add comments

* fix style

* add changes

---------

Co-authored-by: xiangyuT <xiangyu.tian@intel.com>
Co-authored-by: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Co-authored-by: leonardozcm <leonardo1997zcm@gmail.com>
Guancheng Fu 2023-11-23 16:46:45 +08:00 committed by GitHub
parent 48fbb1eb94
commit bf579507c2
33 changed files with 6507 additions and 0 deletions

@@ -192,6 +192,8 @@ def trailing_whitespace(physical_line):
def trailing_blacklist_words(physical_line):
if noqa(physical_line):
return
if physical_line.find("assert ") != -1:
return 0, "Please don't use assert, use log4Error.invalidInputError instead"
if physical_line.find("raise ") != -1:

@@ -8,6 +8,7 @@ This folder contains examples of running BigDL-LLM on Intel CPU:
- [LangChain](LangChain): running LangChain applications on BigDL-LLM
- [Applications](Applications): running Transformers applications on BigDL-LLM
- [QLoRA-FineTuning](QLoRA-FineTuning): running QLoRA finetuning using BigDL-LLM on Intel CPUs
- [vLLM-Serving](vLLM-Serving): running vLLM serving framework on Xeon Platforms (with BigDL-LLM low-bit optimized models)
## System Support
**Hardware**:

@@ -0,0 +1,91 @@
# Serving LLAMA models using vLLM on Intel platforms (experimental support)
This example demonstrates how to serve a llama2-7b model using BigDL-LLM 4-bit optimizations on Xeon CPUs with an adapted vLLM.
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.2.1.post1).
## Example: Serving llama2-7b using Xeon CPU
In this example, we will run the Llama2-7b model using 48 cores in one socket and provide an `OpenAI-compatible` interface for users.
### 1. Install
The original [vLLM](https://github.com/vllm-project/vllm) is designed to run in a `CUDA` environment. To adapt vLLM to `Intel` platforms, install the dependencies as follows:
```bash
# First create a conda environment
conda create -n bigdl-vllm python==3.9
conda activate bigdl-vllm
# Install dependencies
pip install --pre --upgrade bigdl-llm[all]
pip3 install psutil
pip3 install sentencepiece # Required for LLaMA tokenizer.
pip3 install numpy
pip3 install "torch==2.0.1"
pip3 install "transformers>=4.33.1" # Required for Code Llama.
pip3 install "xformers == 0.0.22"
pip3 install fastapi
pip3 install "uvicorn[standard]"
pip3 install "pydantic<2" # Required for OpenAI server.
```
### 2. Configure recommended environment variables
```bash
source bigdl-llm-init
```
### 3. Offline inference/Service
#### Offline inference
To get a quick impression of offline inference with vLLM, run the following example:
```bash
#!/bin/bash
# Please first modify the MODEL_PATH in offline_inference.py
numactl -C 48-95 -m 1 python offline_inference.py
```
#### Service
To fully utilize the dynamic batching feature of `vLLM`, you can send requests to the service using `curl` or similar tools. Requests sent to the engine are batched at the token level: queries are executed in the same `forward` step of the LLM and removed as soon as they finish, instead of waiting for all sequences to finish.
```bash
#!/bin/bash
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.examples.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000 \
--load-format 'auto' --device cpu --dtype bfloat16
```
Then you can access the API server as follows:
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/",
"prompt": "San Francisco is a",
"max_tokens": 128,
"temperature": 0
}' &
```
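Alternatively, since the server exposes an `OpenAI-compatible` API, you can query it from Python. Below is a minimal sketch using only the standard library; the endpoint and payload simply mirror the `curl` command above:
```python
# Minimal Python client for the OpenAI-compatible completions endpoint.
# Assumes the api_server launched above is listening on localhost:8000.
import json
import urllib.request

payload = {
    "model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/",
    "prompt": "San Francisco is a",
    "max_tokens": 128,
    "temperature": 0,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp))
```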
### 4. (Optional) Add a new model
Currently only llama-architecture models (including `llama`, `vicuna`, `llama-2`, etc.) are supported. To use another model, you may need to adapt the code.
#### 4.1 Add model code
Create or clone the PyTorch model code in `./models`.
#### 4.2 Rewrite the forward methods
Referring to `./models/bigdl_llama.py`, you need to maintain a `kv_cache`, which is a nested list of dictionaries mapping each `req_id` to a three-dimensional tensor **(the structure may vary across models)**. Before invoking the model's actual `forward` method, prepare a `past_key_values` entry according to the current `req_id`; afterwards, update the `kv_cache` with `output.past_key_values`. The cache entry is cleared once the request is finished.
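As a rough illustration, the bookkeeping described above might look like the sketch below. The class name, cache layout, and `forward` signature here are hypothetical; the actual structure in `./models/bigdl_llama.py` may differ:
```python
# Hypothetical sketch of per-request KV-cache bookkeeping for a new model.
# The real cache layout (e.g., a nested list of dicts) varies across models.
class MyNewForCausalLM:  # illustrative name
    def __init__(self, model):
        self.model = model
        self.kv_cache = {}  # maps req_id to that request's past_key_values

    def forward(self, req_id, input_ids):
        # Prepare past_key_values for this request (None on the prefill step).
        past_key_values = self.kv_cache.get(req_id)
        output = self.model(input_ids=input_ids,
                            past_key_values=past_key_values,
                            use_cache=True)
        # Update the cache with the newly produced key/value tensors.
        self.kv_cache[req_id] = output.past_key_values
        return output

    def free(self, req_id):
        # Clearance step: drop the cache entry when the request is finished.
        self.kv_cache.pop(req_id, None)
```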
#### 4.3 Register new model
Finally, register your `*ForCausalLM` class in the `_MODEL_REGISTRY` in `./models/model_loader.py`.
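For example, the registration might look roughly like the following (the import path and registry shape are illustrative; check `./models/model_loader.py` for the actual form):
```python
# Hypothetical registry entry in ./models/model_loader.py: maps the
# architecture name from the HF config to the implementing class.
from bigdl.llm.vllm.models.bigdl_llama import BigdlLlamaForCausalLM

_MODEL_REGISTRY = {
    "LlamaForCausalLM": BigdlLlamaForCausalLM,
    # "MyNewForCausalLM": MyNewForCausalLM,  # register your new model here
}
```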

@@ -0,0 +1,57 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bigdl.llm.vllm.examples.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

@@ -0,0 +1,107 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
#tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
#tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
#tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm tests
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

@@ -0,0 +1,417 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/config.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from transformers import AutoConfig, PretrainedConfig
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
class ModelConfig:
"""Configuration for the model.
Args:
model: Name or path of the huggingface model to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
device: The device to be used for the model. If None, we will default
to use CPU as the device.
"""
def __init__(
self,
model: str,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
download_dir: Optional[str],
load_format: str,
dtype: str,
seed: int,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
device: Optional[str] = 'cpu',
) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.download_dir = download_dir
self.load_format = load_format
self.seed = seed
self.revision = revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.device = device
self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len)
self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
if load_format not in [
"auto", "pt", "safetensors", "npcache", "dummy"
]:
invalidInputError(
False,
f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow"]:
invalidInputError(
False,
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
if self.quantization is None:
return
quantization = self.quantization.lower()
if quantization not in supported_quantization:
invalidInputError(
False,
f"Unknown quantization: {self.quantization}. Must be one of "
f"{supported_quantization}.")
self.quantization = quantization
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
total_num_attention_heads = self.hf_config.num_attention_heads
tensor_parallel_size = parallel_config.tensor_parallel_size
if total_num_attention_heads % tensor_parallel_size != 0:
invalidInputError(
False,
f"Total number of attention heads ({total_num_attention_heads})"
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
total_num_hidden_layers = self.hf_config.num_hidden_layers
pipeline_parallel_size = parallel_config.pipeline_parallel_size
if total_num_hidden_layers % pipeline_parallel_size != 0:
invalidInputError(
False,
f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
def get_head_size(self) -> int:
# FIXME(woosuk): This may not be true for all models.
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU worker."""
# For GPTBigCode & Falcon:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
and getattr(self.hf_config, "new_decoder_architecture", False))
if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False):
# Multi-query attention, only one KV head.
# Currently, tensor parallelism is not supported in this case.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
if getattr(self.hf_config, "num_kv_heads", None) is not None:
return (self.hf_config.num_kv_heads //
parallel_config.tensor_parallel_size)
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //
parallel_config.tensor_parallel_size)
total_num_attention_heads = self.hf_config.num_attention_heads
return total_num_attention_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: str,
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
invalidInputError(False, f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
return torch_dtype
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(hf_config, key, None)
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if derived_max_model_len == float("inf"):
if max_model_len is not None:
# If max_model_len is specified, we use it.
return max_model_len
default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "
"keys to determine the original maximum length of the model: "
f"{possible_keys}. Assuming the model's maximum length is "
f"{default_max_len}.")
derived_max_model_len = default_max_len
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
invalidInputError("factor" in rope_scaling,
"invalid hf_config value for rope_scaling")
scaling_factor = rope_scaling["factor"]
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
invalidInputError(
False,
f"User-specified max_model_len ({max_model_len}) is greater than "
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
" in model's config.json). This may lead to incorrect model "
"outputs or CUDA errors. Make sure the value is correct and "
"within the model context size.")
return int(max_model_len)
def get_config(model: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> PretrainedConfig:
# NOTE: Because the Mistral model in HF hub does not have
# `configuration_mistral.py`, we cannot use `AutoConfig` to load the
# config. Instead, we use `MistralConfig` directly.
# NOTE: This is a hack. This does not work for local models.
# FIXME: Remove this once the Mistral model is available in the stable
# version of HF transformers.
if "mistral" in model.lower():
return MistralConfig.from_pretrained(model, revision=revision)
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
invalidInputError(False, err_msg)
else:
invalidInputError(False, str(e))
return config
class ParallelConfig:
"""Configuration for the distributed execution.
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
"""
def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
self._verify_args()
def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
invalidInputError(
False,
"Pipeline parallelism is not supported yet.")
class SchedulerConfig:
"""Scheduler configuration.
Args:
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
iteration.
max_model_len: Maximum length of a sequence (including prompt
and generated text).
"""
def __init__(
self,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
) -> None:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self._verify_args()
def _verify_args(self) -> None:
if self.max_num_batched_tokens < self.max_model_len:
invalidInputError(
False,
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
f"smaller than max_model_len ({self.max_model_len}). "
"This effectively limits the maximum sequence length to "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len.")
if self.max_num_batched_tokens < self.max_num_seqs:
invalidInputError(
False,
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
"be greater than or equal to max_num_seqs "
f"({self.max_num_seqs}).")

@@ -0,0 +1,79 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/core/policy.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from bigdl.llm.vllm.sequence import SequenceGroup
from bigdl.llm.utils.common import invalidInputError
class Policy:
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
invalidInputError(False, "base class not implemented")
def sort_by_priority(
self,
now: float,
seq_groups: List[SequenceGroup],
) -> List[SequenceGroup]:
return sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
)
class FCFS(Policy):
def get_priority(
self,
now: float,
seq_group: SequenceGroup,
) -> float:
return now - seq_group.arrival_time
class PolicyFactory:
_POLICY_REGISTRY = {
'fcfs': FCFS,
}
@classmethod
def get_policy(cls, policy_name: str, **kwargs) -> Policy:
return cls._POLICY_REGISTRY[policy_name](**kwargs)

@@ -0,0 +1,267 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/core/scheduler.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bigdl-llm Intel specified code change
#
import time
from typing import Dict, Iterable, List, Optional, Tuple, Union
from bigdl.llm.vllm.config import SchedulerConfig
from bigdl.llm.vllm.core.policy import PolicyFactory
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.vllm.sequence import SequenceData, SequenceStatus
from bigdl.llm.vllm.sequence import (Sequence, SequenceGroup,
SequenceGroupMetadata)
from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
class SchedulerOutputs:
def __init__(
self,
scheduled_seq_groups: List[SequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
ignored_seq_groups: List[SequenceGroup],
finished_seqs: List[int],
) -> None:
# bigdl-llm change start
# Summary: we are removing block table related arguments
# We also added finished_seqs so that workers know which sequences
# can be safely deleted
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
self.num_batched_tokens = num_batched_tokens
self.ignored_seq_groups = ignored_seq_groups
self.finished_seqs = finished_seqs
# bigdl-llm change end
def is_empty(self) -> bool:
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.finished_seqs)
class FixedWindowScheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
) -> None:
self.scheduler_config = scheduler_config
self.prompt_limit = min(self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
# bigdl-llm change start
# summary: cache_config is removed as we are not implementing the pagetable structure
# block manager is also deleted based on the same reasoning.
# As we are not using the pagetable, the swapped area is also deleted because
# we cannot decide if there is enough memory or not.
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
# Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
self.cleaned: List[int] = []
# Co(gc): We no longer have the swapped space as we are not deciding which to swap
# bigdl-llm change end
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
if isinstance(request_id, str):
request_id = (request_id, )
request_ids = set(request_id)
for state_queue in [self.waiting, self.running]:
# We need to reverse the list as we are removing elements
# from it as we iterate over it. If we don't do it,
# indices will get messed up and we will skip over elements.
for seq_group in reversed(state_queue):
if seq_group.request_id in request_ids:
# Remove the sequence group from the state queue.
state_queue.remove(seq_group)
for seq in seq_group.get_seqs():
if seq.is_finished():
continue
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
request_ids.remove(seq_group.request_id)
if not request_ids:
return
def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running)
def _schedule(self) -> SchedulerOutputs:
# Fix the current time.
now = time.monotonic()
ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = []
finished_seqs: List[int] = self.cleaned.copy()
self.cleaned = []
# The total number of sequences on the fly, including the
# requests in the generation phase.
num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
for seq_group in self.running)
num_batched_tokens = 0
if self.waiting:
# We restrict how many requests can be run using these three arguments
# Co(gc): If there are waiting requests, we will try to add them into the
# running state as long as the limits below are not exceeded
# Co(gc): prefilled requests are prioritized over decoding stage requests
while self.waiting:
seq_group = self.waiting[0]
invalidInputError(seq_group.num_seqs() == 1,
"Waiting sequence group should have only one prompt "
"sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
if num_prompt_tokens > self.prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
continue
# bigdl-llm change start
# summary: removing block_manager related logic.
# TODO(gc): If you can manage to make block_manager work,
# then this will be fine.
# If the sequence group cannot be allocated, stop.
# if not self.block_manager.can_allocate(seq_group):
# break
# bigdl-llm change end
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.get_max_num_running_seqs()
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
# bigdl-llm change start
# summary: removing block_manager related logic.
# self._allocate(seq_group)
# bigdl-llm change end
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
num_curr_seqs += num_new_seqs
scheduled.append(seq_group)
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
ignored_seq_groups=ignored_seq_groups,
finished_seqs=finished_seqs,
)
return scheduler_outputs
# Now consider all the requests in decoding stage
self.running = self.policy.sort_by_priority(now, self.running)
# Each sequence in the generation phase only takes one token slot.
# Therefore, the number of batched tokens is equal to the number of
# sequences in the RUNNING state.
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
prompt_run=False,
num_batched_tokens=num_batched_tokens,
ignored_seq_groups=[],
finished_seqs=finished_seqs,
)
return scheduler_outputs
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running and self.waiting.
scheduler_outputs = self._schedule()
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in scheduler_outputs.scheduled_seq_groups:
seq_data: Dict[int, List[SequenceData]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=scheduler_outputs.prompt_run,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
def free_seq(self, seq: Sequence) -> None:
# bigdl-llm change start
# summary: The original code freed the block in the block_manager.
# Now, we add the seq_id into a list that is passed to the workers in the
# next execute_model stage so that they can clean up the KV cache.
self.cleaned.append(seq.seq_id)
# bigdl-llm change end
def free_finished_seq_groups(self) -> None:
self.running = [
seq_group for seq_group in self.running
if not seq_group.is_finished()
]

@@ -0,0 +1,14 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -0,0 +1,267 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/engine/arg_utils.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bigdl-llm Intel specified code change
#
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
from bigdl.llm.vllm.config import ModelConfig, SchedulerConfig
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
dtype: str = 'auto'
seed: int = 0
max_model_len: Optional[int] = None
# bigdl-llm change start
# summary: disable model parallel, tensor parallel
# worker_use_ray: bool = False
# pipeline_parallel_size: int = 1
# tensor_parallel_size: int = 1
# bigdl-llm change end
block_size: int = 16
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
# bigdl-llm change start
# summary: add device option
device: Optional[str] = 'cpu'
# bigdl-llm change end
def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
parser.add_argument(
'--model',
type=str,
default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument(
'--tokenizer',
type=str,
default=EngineArgs.tokenizer,
help='name or path of the huggingface tokenizer to use')
parser.add_argument(
'--revision',
type=str,
default=None,
help='the specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--tokenizer-revision',
type=str,
default=None,
help='the specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument('--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow'],
help='tokenizer mode. "auto" will use the fast '
'tokenizer if available, and "slow" will '
'always use the slow tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument('--download-dir',
type=str,
default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available. '
'"pt" will load the weights in the pytorch bin format. '
'"safetensors" will load the weights in the safetensors format. '
'"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
parser.add_argument(
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--max-model-len',
type=int,
default=None,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
# disable Parallel arguments
# parser.add_argument('--worker-use-ray',
# action='store_true',
# help='use Ray for distributed serving, will be '
# 'automatically set when using more than 1 GPU')
# parser.add_argument('--pipeline-parallel-size',
# '-pp',
# type=int,
# default=EngineArgs.pipeline_parallel_size,
# help='number of pipeline stages')
# parser.add_argument('--tensor-parallel-size',
# '-tp',
# type=int,
# default=EngineArgs.tensor_parallel_size,
# help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='random seed')
parser.add_argument('--swap-space',
type=int,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used '
'for the model executor')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
parser.add_argument('--device',
type=str,
choices=['gpu', 'cpu', 'xpu', None],
default=None,
help='Device to execute LLM model')
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
return engine_args
def create_engine_configs(self, ) -> Tuple[ModelConfig, SchedulerConfig]:
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization, self.device)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len)
# bigdl-llm change start
# summary: remove parallel config + cache config
# parallel_config = ParallelConfig(self.pipeline_parallel_size,
# self.tensor_parallel_size, False)
# bigdl-llm change end
return model_config, scheduler_config
@dataclass
class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
# bigdl-llm change start
# summary: disable ray
# engine_use_ray: bool = False
# bigdl-llm change end
disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--disable-log-requests',
action='store_true',
help='disable logging requests')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='max number of prompt characters or prompt '
'ID numbers being printed in log. '
'Default: unlimited.')
return parser

@@ -0,0 +1,537 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/engine/async_llm_engine.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bigdl-llm Intel specified code change
#
import asyncio
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union)
from bigdl.llm.vllm.config import ModelConfig
from bigdl.llm.vllm.engine.arg_utils import AsyncEngineArgs
from bigdl.llm.vllm.engine.llm_engine import LLMEngine
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.vllm.outputs import RequestOutput
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
class AsyncStream:
"""A stream of RequestOutputs for a request that can be
iterated over asynchronously."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._finished = False
def put(self, item: RequestOutput) -> None:
if self._finished:
return
self._queue.put_nowait(item)
def finish(self) -> None:
self._queue.put_nowait(StopIteration)
self._finished = True
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
result = await self._queue.get()
if result is StopIteration:
raise StopAsyncIteration # noqa
elif isinstance(result, Exception):
raise result # noqa
return result
def _raise_exception_on_finish(task: asyncio.Task,
request_tracker: "RequestTracker") -> None:
msg = ("Task finished unexpectedly. This should never happen! "
"Please open an issue on Github.")
try:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
invalidInputError(
False, msg + " See stack trace above for the actual cause.")
invalidInputError(False, msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
invalidInputError(False, "")
class RequestTracker:
"""Synchronous abstraction for tracking requests."""
def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
return item in self._request_streams
def init_event(self):
self.new_requests_event = asyncio.Event()
def propagate_exception(self,
exc: Exception,
request_id: Optional[str] = None) -> None:
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
else:
for stream in self._request_streams.values():
stream.put(exc)
def process_request_output(self,
request_output: RequestOutput,
*,
verbose: bool = False) -> None:
"""Process a request output from the engine."""
request_id = request_output.request_id
self._request_streams[request_id].put(request_output)
if request_output.finished:
if verbose:
logger.info(f"Finished request {request_id}.")
self.abort_request(request_id, verbose=verbose)
def add_request(self, request_id: str,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
if request_id in self._request_streams:
invalidInputError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
self._new_requests.put_nowait((stream, {
"request_id": request_id,
**engine_add_request_kwargs
}))
self.new_requests_event.set()
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info(f"Aborted request {request_id}.")
self._finished_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[
request_id].finished:
# The request has already finished or been aborted.
return
self._request_streams[request_id].finish()
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[dict] = []
finished_requests: Set[str] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
stream.finish()
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copied. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
if scheduler_outputs.is_empty():
return ignored
# Execute the model.
# Co(gc): Now that we do not have page table support, we need to pass the
# list of sequences that have been finished so that we can clean the KVCache.
# bigdl-llm change start
# summary: this is the interface between the upper layer and the lower layer;
# we pass the finished_seqs down to the model.
output = await self._run_workers_async(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in={},
blocks_to_swap_out={},
blocks_to_copy={},
finished_seqs=scheduler_outputs.finished_seqs,
)
# bigdl-llm change end
return self._process_model_outputs(output, scheduler_outputs) + ignored
async def _run_workers_async(
self,
method: str,
*args,
get_all_outputs: bool=False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
# bigdl-llm change start
all_outputs = []
for worker in self.workers:
# if self.parallel_config.worker_use_ray:
# executor = partial(worker.execute_method.remote, method)
# else:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
all_outputs.append(output)
# if self.parallel_config.worker_use_ray:
# all_outputs = await asyncio.gather(*all_outputs)
if get_all_outputs:
return all_outputs
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
invalidInputError(output == other_output,
"all workers do not have the same result")
return output
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args, *kwargs: Arguments for LLMEngine.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
def __init__(self,
# worker_use_ray: bool,
# engine_use_ray: bool,
*args,
log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True,
**kwargs) -> None:
# self.worker_use_ray = worker_use_ray
# self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.background_loop = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and not self.background_loop.done())
def start_background_loop(self) -> None:
"""Start the background loop."""
if self.is_running:
invalidInputError(False, "Background loop is already running.")
self._request_tracker.init_event()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish,
request_tracker=self._request_tracker))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args, **kwargs) -> _AsyncLLMEngine:
# Co(gc): we disable ray here
# if not self.engine_use_ray:
engine_class = self._engine_class
# elif self.worker_use_ray:
# engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
# else:
# engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, finished_requests = (
self._request_tracker.get_new_and_finished_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
# TODO: Maybe add add_request_batch to reduce Ray overhead
# if self.engine_use_ray:
# await self.engine.add_request.remote(**new_request)
# else:
self.engine.add_request(**new_request)
if finished_requests:
await self._engine_abort(finished_requests)
# if self.engine_use_ray:
# request_outputs = await self.engine.step.remote()
# else:
request_outputs = await self.engine.step_async()
# Put the outputs into the corresponding streams.
for request_output in request_outputs:
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
return len(request_outputs) > 0
async def _engine_abort(self, request_ids: Iterable[str]):
# if self.engine_use_ray:
# await self.engine.abort_request.remote(request_ids)
# else:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
# Initialize the RequestTracker here so it uses the right event loop.
has_requests_in_progress = False
while True:
if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
has_requests_in_progress = await self.engine_step()
await asyncio.sleep(0)
async def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]]=None,
arrival_time: Optional[float]=None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"sampling params: {sampling_params}, "
f"prompt token ids: {shortened_token_ids}.")
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
else:
invalidInputError(
False,
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
return stream
async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]]=None) -> RequestOutput:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
request into the waiting queue of the LLMEngine and streams the outputs
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
request.
"""
# Preprocess the request.
# This should not be used for logging, as it is monotonic time.
arrival_time = time.monotonic()
try:
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or the coroutine is cancelled, abort the
# request.
self._abort(request_id)
invalidInputError(False, f"Exception occurred during generation: {e}")
async def abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
if not self.is_running:
invalidInputError(
False,
"Background loop is not running. If it was running, "
"inspect the output to find the stacktrace of the "
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
return self._abort(request_id)
def _abort(self, request_id: str) -> None:
"""Abort a request.
Abort a submitted request. If the request is finished or not found,
this method will be a no-op.
Args:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
# if self.engine_use_ray:
# return await self.engine.get_model_config.remote()
# else:
return self.engine.get_model_config()
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
# bigdl-llm code change start
# summary: remove unrelated configs from the code
# parallel_config = engine_configs[2]
# Initialize the cluster.
# port = get_open_port()
# distributed_init_method = f"tcp://localhost:{port}"
# Create the async LLM engine.
engine = cls(
# engine_args.worker_use_ray,
# engine_args.engine_use_ray,
# TODO: we use one less here
*engine_configs,
# distributed_init_method,
# None,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
# bigdl-llm code change end
return engine
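
For reference, a minimal sketch of driving this async engine directly (the model path is a placeholder, and we assume `AsyncEngineArgs` accepts the same `device` option that the offline `LLM` entrypoint forwards to `EngineArgs`):

```python
import asyncio

from bigdl.llm.vllm.engine.arg_utils import AsyncEngineArgs
from bigdl.llm.vllm.engine.async_llm_engine import AsyncLLMEngine
from bigdl.llm.vllm.sampling_params import SamplingParams


async def main():
    # Placeholder model path; device="cpu" is assumed to be supported here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="/path/to/llama-2-7b-hf", device="cpu"))
    params = SamplingParams(temperature=0.8, max_tokens=64)
    # generate() is an async generator that yields incremental
    # RequestOutputs; the background loop starts on the first call.
    final_output = None
    async for request_output in engine.generate("What is AI?", params,
                                                request_id="req-0"):
        final_output = request_output
    print(final_output.outputs[0].text)


asyncio.run(main())
```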

View file

@@ -0,0 +1,718 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/engine/llm_engine.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bigdl-llm Intel-specific code change
#
import time
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
from bigdl.llm.vllm.config import ModelConfig, SchedulerConfig
from bigdl.llm.vllm.core.scheduler import SchedulerOutputs, FixedWindowScheduler
from bigdl.llm.vllm.engine.arg_utils import EngineArgs
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.vllm.outputs import RequestOutput
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.vllm.sequence import (
SamplerOutput,
Sequence,
SequenceGroup,
SequenceGroupMetadata,
SequenceStatus,
SequenceOutputs,
)
from bigdl.llm.vllm.transformers_utils.tokenizer import get_tokenizer, detokenize_incrementally
from bigdl.llm.vllm.utils import (
Counter,
)
from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
This is the main class for the vLLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model_config: The configuration related to the LLM model.
scheduler_config: The configuration related to the request scheduler.
log_stats: Whether to log statistics.
"""
def __init__(
self,
model_config: ModelConfig,
# parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
# distributed_init_method: str,
# placement_group,
log_stats: bool,
) -> None:
# bigdl-llm change start
# summary: removing parallel_config and related checks.
# distributed_init_method/placement_group is related to these configs
# so they are removed too.
logger.info(
"Initializing an LLM engine with config: "
f"model={model_config.model!r}, "
f"tokenizer={model_config.tokenizer!r}, "
f"tokenizer_mode={model_config.tokenizer_mode}, "
f"revision={model_config.revision}, "
f"tokenizer_revision={model_config.tokenizer_revision}, "
f"trust_remote_code={model_config.trust_remote_code}, "
f"dtype={model_config.dtype}, "
f"max_seq_len={model_config.max_model_len}, "
f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, "
# f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"seed={model_config.seed})"
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
# self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
# self._verify_args()
self.tokenizer = get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
tokenizer_revision=model_config.tokenizer_revision,
revision=model_config.revision,
)
self.seq_counter = Counter()
# Create the parallel GPU workers.
self._init_workers()
# Co(gc): we create a fixed scheduler
self.scheduler = FixedWindowScheduler(scheduler_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
# bigdl-llm change end
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from bigdl.llm.vllm.worker.worker import (
Worker,
) # pylint: disable=import-outside-toplevel
# invalidInputError(
# self.parallel_config.world_size == 1,
# "Ray is required if parallel_config.world_size > 1.",
# )
self.workers: List[Worker] = []
worker = Worker(
self.model_config,
self.scheduler_config,
0,
# distributed_init_method,
)
self.workers.append(worker)
self._run_workers(
"init_model",
get_all_outputs=True,
)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
# Co(gc): this simply checks if the swap is too large or not
# self.cache_config.verify_with_parallel_config(self.parallel_config)
@classmethod
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# bigdl-llm change start
# summary: remove parallel_config and related settings.
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
# parallel_config = engine_configs[2]
# Initialize cluster locally.
# port = get_open_port()
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
# distributed_init_method = f"tcp://localhost:{port}"
# Create the LLM engine.
engine = cls(
*engine_configs,
# distributed_init_method,
# None,
log_stats=not engine_args.disable_log_stats,
)
return engine
def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
scheduler as `engine.step()` is called. The exact scheduling policy is
determined by the scheduler.
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
sampling_params: The sampling parameters for text generation.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
"""
if arrival_time is None:
arrival_time = time.monotonic()
if prompt_token_ids is None:
invalidInputError(prompt is not None, "Prompt should not be None")
prompt_token_ids = self.tokenizer.encode(prompt)
# Create the sequences.
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids)
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
Args:
request_id: The ID(s) of the request to abort.
"""
self.scheduler.abort_seq_group(request_id)
def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
def _schedule(
self,
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[RequestOutput]]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
return (
seq_group_metadata_list,
scheduler_outputs,
[
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
],
)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
invalidInputError(sampling_params.use_beam_search, "use_beam_search should be enabled")
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
)
else:
invalidInputError(
early_stopping == "never", "early_stopping should be never"
)
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() + sampling_params.max_tokens,
self.scheduler_config.max_model_len,
)
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
seq_len=max_possible_length,
)
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=self.tokenizer.eos_token_id,
)
return current_worst_score >= highest_attainable_score
def _process_sequence_group_samples(
self, seq_group: SequenceGroup, samples: List[SequenceOutputs]
) -> None:
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs}
# Co(gc): parent_child_dict = {seq_id: [SequenceOutputs]}
for sample in samples:
parent_child_dict[sample.parent_seq_id].append(sample)
# List of (child, parent)
child_seqs: List[Tuple[Sequence, Sequence]] = []
# Process the child samples for each parent sequence
# Co(gc): For each child sample, create a sequence and add it to child_seqs
for parent in parent_seqs:
# Co(gc): Get all the child_samples (SequenceOutputs)
child_samples: List[SequenceOutputs] = parent_child_dict[parent.seq_id]
# We do not have any SequenceOutputs
if len(child_samples) == 0:
# This parent sequence has no children samples. Remove
# the parent sequence from the sequence group since it will
# not be used in the future iterations.
parent.status = SequenceStatus.FINISHED_ABORTED
seq_group.remove(parent.seq_id)
self.scheduler.free_seq(parent)
continue
# Fork the parent sequence if there are multiple child samples.
# Co(gc): The outputs diverges, we need to fork the requests
for child_sample in child_samples[:-1]:
new_child_seq_id = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(
child_sample.output_token,
child_sample.logprobs,
child_sample.latency,
)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(
last_child_sample.output_token,
last_child_sample.logprobs,
last_child_sample.latency,
)
child_seqs.append((parent, parent))
for seq, _ in child_seqs:
self._decode_sequence(seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params)
# Non-beam search case
if not seq_group.sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
pass
# bigdl-llm change start
# summary: fork_seq is doing some block manager ops, so we remove this
# self.scheduler.fork_seq(parent, seq)
# bigdl-llm change end
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs = []
unselected_child_seqs = []
beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs]
new_finished_seqs = [
(seq, parent, True) for seq, parent in child_seqs if seq.is_finished()
]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(
key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
),
reverse=True,
)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need to remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [
(seq, parent) for seq, parent in child_seqs if not seq.is_finished()
]
# Sort the running sequences by their scores.
running_child_seqs.sort(
key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id
),
reverse=True,
)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
seq_group.sampling_params.early_stopping,
seq_group.sampling_params,
best_running_seq,
current_worst_seq,
)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
pass
# bigdl-llm change start
# summary: fork_seq is doing some block manager ops, so we remove this
# self.scheduler.fork_seq(parent, seq)
# bigdl-llm change end
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
if seq is parent and seq.is_finished():
self.scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
self.scheduler.free_seq(seq)
def _process_model_outputs(
self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs
) -> List[RequestOutput]:
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
for seq_group, samples in zip(scheduled_seq_groups, output):
self._process_sequence_group_samples(seq_group, samples)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
# Co(gc): we disable the gpu stats part in the function
if self.log_stats:
# Log the system stats.
self._log_system_stats(
scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens
)
return request_outputs
def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first
schedules the sequences to be executed in the next iteration and the
token blocks to be swapped in/out/copy. Then, it executes the model
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
if scheduler_outputs.is_empty():
return ignored
# Execute the model.
output = self._run_workers(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in={},
blocks_to_swap_out={},
blocks_to_copy={},
finished_seqs=scheduler_outputs.finished_seqs,
)
return self._process_model_outputs(output, scheduler_outputs) + ignored
def _log_system_stats(
self,
prompt_run: bool,
num_batched_tokens: int,
) -> None:
now = time.monotonic()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time < _LOGGING_INTERVAL_SEC:
return
# Discard the old stats.
self.num_prompt_tokens = [
(t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC
]
self.num_generation_tokens = [
(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC
]
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
# bigdl-llm change start
# summary: removing logging of pagetable related arguments
# total_num_gpu_blocks = self.cache_config.num_gpu_blocks
# num_free_gpu_blocks = (
# self.scheduler.block_manager.get_num_free_gpu_blocks())
# num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
# gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
# total_num_cpu_blocks = self.cache_config.num_cpu_blocks
# if total_num_cpu_blocks > 0:
# num_free_cpu_blocks = (
# self.scheduler.block_manager.get_num_free_cpu_blocks())
# num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
# cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
# else:
# cpu_cache_usage = 0.0
# bigdl-llm change end
logger.info(
"Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
)
self.last_logging_time = now
def _decode_sequence(self, seq: Sequence, sampling_params: SamplingParams) -> None:
"""Decodes the new token for a sequence."""
(
new_tokens,
new_output_text,
prefix_offset,
read_offset,
) = detokenize_incrementally(
self.tokenizer,
all_input_ids=seq.get_token_ids(),
prev_tokens=seq.tokens,
prefix_offset=seq.prefix_offset,
read_offset=seq.read_offset,
skip_special_tokens=sampling_params.skip_special_tokens,
)
if seq.tokens is None:
seq.tokens = new_tokens
else:
seq.tokens.extend(new_tokens)
seq.prefix_offset = prefix_offset
seq.read_offset = read_offset
seq.output_text += new_output_text
def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[: -len(stop_str)]
seq.status = SequenceStatus.FINISHED_STOPPED
return
if seq.get_last_token_id() in sampling_params.stop_token_ids:
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return
# Check if the sequence has generated the EOS token.
if (
not sampling_params.ignore_eos
) and seq.get_last_token_id() == self.tokenizer.eos_token_id:
seq.status = SequenceStatus.FINISHED_STOPPED
return
def _run_workers(
self,
method: str,
*args,
get_all_outputs: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
all_outputs = []
for worker in self.workers:
# bigdl-llm change start
# summary: we disable ray here
# if self.parallel_config.worker_use_ray:
# executor = partial(worker.execute_method.remote, method)
# else:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
all_outputs.append(output)
# if self.parallel_config.worker_use_ray:
# all_outputs = ray.get(all_outputs)
# bigdl-llm change end
if get_all_outputs:
return all_outputs
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
invalidInputError(
output == other_output, "All workers should have same output"
)
return output
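
A similar sketch for the synchronous engine above, assuming a local model path: requests are queued with `add_request` and drained by calling `step()` until the scheduler reports no unfinished work (this is what the `LLM` entrypoint's `_run_engine` does):

```python
from bigdl.llm.vllm.engine.arg_utils import EngineArgs
from bigdl.llm.vllm.engine.llm_engine import LLMEngine
from bigdl.llm.vllm.sampling_params import SamplingParams

# Placeholder model path.
engine = LLMEngine.from_engine_args(
    EngineArgs(model="/path/to/llama-2-7b-hf", device="cpu"))
engine.add_request("req-0", "What is AI?", SamplingParams(max_tokens=32))

# Each step() schedules one iteration, runs the worker, and returns the
# RequestOutputs produced in that iteration.
while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.outputs[0].text)
```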

View file

@@ -0,0 +1,14 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@@ -0,0 +1,235 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/entrypoints/llm.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bigdl-llm Intel-specific code change
#
from typing import List, Optional, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from bigdl.llm.vllm.engine.arg_utils import EngineArgs
from bigdl.llm.vllm.engine.llm_engine import LLMEngine
from bigdl.llm.vllm.outputs import RequestOutput
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.vllm.utils import Counter
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
device: The device to be used for the model. If None, we default
to CPU.
"""
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
# bigdl-llm change start
# summary: add device option
device: Optional[str] = "cpu",
# bigdl-llm change end
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
# bigdl-llm change start
# summary: disable parallel config.
# tensor_parallel_size=tensor_parallel_size,
# bigdl-llm change end
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
device=device,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
self.request_counter = Counter()
def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]]=None,
sampling_params: Optional[SamplingParams]=None,
prompt_token_ids: Optional[List[List[int]]]=None,
use_tqdm: bool = True,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be " "provided.") # noqa
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if prompts is not None and prompt_token_ids is not None:
if len(prompts) != len(prompt_token_ids):
raise ValueError(  # noqa
"The lengths of prompts and prompt_token_ids must be the same."
)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Add requests to the engine.
if prompts is not None:
num_requests = len(prompts)
else:
num_requests = len(prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
if prompt_token_ids is None:
token_ids = None
else:
token_ids = prompt_token_ids[i]
self._add_request(prompt, sampling_params, token_ids)
return self._run_engine(use_tqdm)
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id, prompt, sampling_params, prompt_token_ids
)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
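
Putting it together, offline batched inference through this class looks like the following sketch (the import path assumes vLLM's `entrypoints` layout is kept; the model path is a placeholder):

```python
from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams

# device="cpu" is the bigdl-llm addition; the model path is a placeholder.
llm = LLM(model="/path/to/llama-2-7b-hf", device="cpu")
outputs = llm.generate(
    ["What is AI?", "What is BigDL?"],
    SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64))
for output in outputs:
    print(f"{output.prompt!r} -> {output.outputs[0].text!r}")
```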

View file

@@ -0,0 +1,14 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@@ -0,0 +1,733 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/entrypoints/openai/api_server.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
#
# bigdl-llm Intel-specific code change
#
import argparse
import asyncio
import json
import time
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
import fastapi
import uvicorn
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from packaging import version
import numpy as np
from bigdl.llm.vllm.engine.arg_utils import AsyncEngineArgs
from bigdl.llm.vllm.engine.async_llm_engine import AsyncLLMEngine
from bigdl.llm.vllm.entrypoints.openai.protocol import (
CompletionResponse, CompletionResponseChoice,
CompletionResponseStreamChoice, CompletionStreamResponse,
ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage,
DeltaMessage, ErrorResponse, LogProbs, ModelCard, ModelPermission,
UsageInfo)
from bigdl.llm.vllm.entrypoints.openai.openai_protocol import (
CompletionRequest, ChatCompletionRequest,
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
ModelList)
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.vllm.outputs import RequestOutput
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.vllm.transformers_utils.tokenizer import get_tokenizer
import uuid
from bigdl.llm.utils.common import invalidInputError
try:
import fastchat
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
_fastchat_available = True
except ImportError:
_fastchat_available = False
TIMEOUT_KEEP_ALIVE = 5 # seconds
logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()
engine = None
def random_uuid() -> str:
return str(uuid.uuid4().hex)
def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(ErrorResponse(message=message,
type="invalid_request_error").dict(),
status_code=status_code.value)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret
async def get_gen_prompt(request) -> str:
if not _fastchat_available:
invalidInputError(
False,
"fastchat is not installed. Please install fastchat to use "
"the chat completion and conversation APIs: `$ pip install fschat`"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
invalidInputError(
False,
f"fastchat version is low. Current version: {fastchat.__version__} "
"Please upgrade fastchat to use: `$ pip install -U fschat`")
conv = get_conversation_template(request.model)
conv = Conversation(
name=conv.name,
system_template=conv.system_template,
system_message=conv.system_message,
roles=conv.roles,
messages=list(conv.messages), # prevent in-place modification
offset=conv.offset,
sep_style=SeparatorStyle(conv.sep_style),
sep=conv.sep,
sep2=conv.sep2,
stop_str=conv.stop_str,
stop_token_ids=conv.stop_token_ids,
)
if isinstance(request.messages, str):
prompt = request.messages
else:
for message in request.messages:
msg_role = message["role"]
if msg_role == "system":
conv.system_message = message["content"]
elif msg_role == "user":
conv.append_message(conv.roles[0], message["content"])
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message["content"])
else:
invalidInputError(False, f"Unknown role: {msg_role}")
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
return prompt
async def check_length(
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str]=None,
prompt_ids: Optional[List[int]]=None
) -> Tuple[List[int], Optional[JSONResponse]]:
invalidInputError((not (prompt is None and prompt_ids is None)
and not (prompt is not None and prompt_ids is not None)
), "Either prompt or prompt_ids should be provided.")
if prompt_ids is not None:
input_ids = prompt_ids
else:
input_ids = tokenizer(prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = max_model_len - token_num
if token_num + request.max_tokens > max_model_len:
return input_ids, create_error_response(
HTTPStatus.BAD_REQUEST,
f"This model's maximum context length is {max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.",
)
else:
return input_ids, None
@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model,
root=served_model,
permission=[ModelPermission()])
]
return ModelList(data=model_cards)
def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
logprobs.top_logprobs.append({
tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()
})
return logprobs
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (to be supported by vLLM engine)
"""
logger.info(f"Received chat completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
prompt = await get_gen_prompt(request)
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
max_tokens=request.max_tokens,
best_of=request.best_of,
top_k=request.top_k,
ignore_eos=request.ignore_eos,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
output_token_latency: Optional[List[float]] = None,
) -> str:
if output_token_latency is None:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
)
elif len(output_token_latency) == 1:
# bigdl-llm change start
# summary: add token-time recording related logic
# other modifications follow the same logic
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
first_token_time=output_token_latency[0])
else:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
finish_reason=finish_reason,
first_token_time=output_token_latency[0],
rest_token_time=np.mean(output_token_latency[1:]))
# bigdl-llm change end
response = ChatCompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
# First chunk with role
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role="assistant"),
finish_reason=None,
output_token_latency=None,
)
chunk = ChatCompletionStreamResponse(id=request_id,
choices=[choice_data],
model=model_name)
data = chunk.json(exclude_unset=True, ensure_ascii=False)
yield f"data: {data}\n\n"
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
output_token_latency=output.output_token_latency,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
response_json = create_stream_response_json(
index=i,
text="",
finish_reason=output.finish_reason,
output_token_latency=output.output_token_latency,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if request.stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
invalidInputError(final_res is not None, "final result should not be None")
choices = []
for output in final_res.outputs:
if output.output_token_latency is None:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
)
else:
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role="assistant", content=output.text),
finish_reason=output.finish_reason,
first_token_time=output.output_token_latency[0],
rest_token_time=np.mean(output.output_token_latency[1:]),
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
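
As a hypothetical client-side check, the chat endpoint above can be exercised with a plain HTTP request once the server is running (host, port, and model name below are assumptions):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "/path/to/llama-2-7b-hf",  # must match the served model
        "messages": [{"role": "user", "content": "What is AI?"}],
        "max_tokens": 64,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```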
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following features:
- echo (since the vLLM engine does not currently support
getting the logprobs of prompt tokens)
- suffix (the language models we currently support do not support
suffix)
- logit_bias (to be supported by vLLM engine)
"""
logger.info(f"Received completion request: {request}")
error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret
if request.echo:
# We do not support echo since the vLLM engine does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")
if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")
if request.logit_bias is not None and len(request.logit_bias) > 0:
# TODO: support logit_bias in vLLM engine.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
use_token_ids = False
if isinstance(request.prompt, list):
if len(request.prompt) == 0:
return create_error_response(HTTPStatus.BAD_REQUEST,
"please provide at least one prompt")
first_element = request.prompt[0]
if isinstance(first_element, int):
use_token_ids = True
prompt = request.prompt
elif isinstance(first_element, (str, list)):
# TODO: handle the multiple-prompt case in List[List[int]]
if len(request.prompt) > 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
"multiple prompts in a batch is not currently supported")
use_token_ids = not isinstance(first_element, str)
prompt = request.prompt[0]
else:
prompt = request.prompt
if use_token_ids:
_, error_check_ret = await check_length(request, prompt_ids=prompt)
else:
token_ids, error_check_ret = await check_length(request, prompt=prompt)
if error_check_ret is not None:
return error_check_ret
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
stop_token_ids=request.stop_token_ids,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
skip_special_tokens=request.skip_special_tokens,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
if use_token_ids:
result_generator = engine.generate(None,
sampling_params,
request_id,
prompt_token_ids=prompt)
else:
result_generator = engine.generate(prompt, sampling_params, request_id,
token_ids)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when using beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
def create_stream_response_json(
index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None,
output_token_latency: Optional[List[float]] = None,
) -> str:
if output_token_latency is None:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
elif len(output_token_latency) == 1:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
first_token_time=output_token_latency[0])
else:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
first_token_time=output_token_latency[0],
rest_token_time=np.mean(output_token_latency[1:]),
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)
return response_json
async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
output_token_latency=output.output_token_latency,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = (LogProbs()
if request.logprobs is not None else None)
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
output_token_latency=output.output_token_latency,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
# Streaming response
if stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")
# Non-streaming response
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return create_error_response(HTTPStatus.BAD_REQUEST,
"Client disconnected")
final_res = res
invalidInputError(final_res is not None, "final result should not be None")
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
else:
logprobs = None
if output.output_token_latency is None:
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
else:
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
first_token_time=output.output_token_latency[0],
rest_token_time=np.mean(output.output_token_latency[1:]),
)
choices.append(choice_data)
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
if request.stream:
# When the user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")
return response
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
logger.info(f"args: {args}")
if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_model_config = asyncio.run(engine.get_model_config())
max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(engine_args.tokenizer,
tokenizer_mode=engine_args.tokenizer_mode,
trust_remote_code=engine_args.trust_remote_code)
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
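# Illustrative request (hypothetical host and model name): once the server is
# up, an OpenAI-style completion can be issued with e.g.
#   curl http://localhost:8000/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "llama-2-7b", "prompt": "Hello", "max_tokens": 16}'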

View file

@ -0,0 +1,229 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/entrypoints/openai/protocol.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
#
# bigdl-llm Intel specified code change
#
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from bigdl.llm.vllm.utils import random_uuid
# bigdl-llm change start
# summary: add token time recording logic
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: Optional[str] = None
class ModelPermission(BaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: Union[str, List[Dict[str, str]]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
class CompletionRequest(BaseModel):
model: str
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Optional[Dict[str,
float]]] = Field(default_factory=list)
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
# bigdl-llm change end
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
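# Illustrative payload (hypothetical model name) accepted by CompletionRequest:
#   CompletionRequest(model="llama-2-7b", prompt="Hello", max_tokens=16,
#                     temperature=0.8, stream=True)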

View file

@ -0,0 +1,103 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/entrypoints/openai/protocol.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
#
# bigdl-llm Intel specified code change
#
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from bigdl.llm.vllm.utils import random_uuid
from bigdl.llm.vllm.entrypoints.openai.openai_protocol import (
ErrorResponse, ModelPermission, ModelCard, UsageInfo, LogProbs, ChatMessage, DeltaMessage
)
# bigdl-llm change start
# summary: add token time recording logic
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length"]] = None
first_token_time: Optional[float] = None
rest_token_time: Optional[float] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
# bigdl-llm change end

View file

@ -0,0 +1,81 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/logger.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
class NewLineFormatter(logging.Formatter):
"""Adds logging prefix to newlines to align multi-line messages."""
def __init__(self, fmt, datefmt=None):
logging.Formatter.__init__(self, fmt, datefmt)
def format(self, record):
msg = logging.Formatter.format(self, record)
if record.message != "":
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg
_root_logger = logging.getLogger("vllm")
_default_handler = None
def _setup_logger():
_root_logger.setLevel(logging.DEBUG)
global _default_handler
if _default_handler is None:
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.flush = sys.stdout.flush # type: ignore
_default_handler.setLevel(logging.INFO)
_root_logger.addHandler(_default_handler)
fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
_default_handler.setFormatter(fmt)
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False
# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()
def init_logger(name: str):
return logging.getLogger(name)
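# Usage sketch: a logger named under the "vllm" namespace inherits the
# stdout handler configured above.
#   logger = init_logger("vllm.engine")
#   logger.info("engine started")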

View file

@ -0,0 +1,21 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from bigdl.llm.vllm.model_executor.model_loader import get_model
__all__ = [
"get_model",
]

View file

@ -0,0 +1,91 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/model_executor/input_metadata.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple
import torch
from xformers.ops import AttentionBias
from bigdl.llm.vllm.sequence import SequenceData
from bigdl.llm.vllm.sampling_params import SamplingParams
class InputMetadata:
"""Metadata for input sequences. Used for PagedAttention.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
context_lens: The length of the attention context for each generation token.
max_context_len: The maximum context length.
sliding_window: The sliding-window size, if any.
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
context_lens: torch.Tensor,
max_context_len: int,
sliding_window: Optional[int] = None,
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.context_lens = context_lens
self.max_context_len = max_context_len
self.to_cache = None
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
# Set during the execution of the first attention op.
# TODO(gc): we might want to delete this
self.attn_bias: List[AttentionBias] = []
def __repr__(self) -> str:
# Print only useful metadata (attributes that are actually set above).
return (f'InputMetadata('
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len})')
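# Illustrative construction (hypothetical values; assumes SequenceData is
# built from the prompt token ids): one prompt of length 3, no generation
# tokens.
#   meta = InputMetadata(seq_groups=[([0], SamplingParams())],
#                        seq_data={0: SequenceData([1, 2, 3])},
#                        prompt_lens=[3],
#                        context_lens=torch.tensor([], dtype=torch.int),
#                        max_context_len=0)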

View file

@ -0,0 +1,544 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/model_executor/layers/sampler.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# bigdl-llm Intel specified code change
#
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
from bigdl.llm.vllm.sampling_params import SamplingParams, SamplingType
from bigdl.llm.vllm.sequence import (SamplerOutput, SequenceGroupMetadata,
SequenceData, SequenceOutputs)
import time
from bigdl.llm.utils.common import invalidInputError
_SAMPLING_EPS = 1e-5
# Summary: removed the tensor-parallel all-gather method used for distributed GPU execution.
# bigdl-llm change start
# def tensor_model_parallel_all_gather(input_, dim=-1):
# """All-gather the input tensor across model parallel group."""
# world_size = get_tensor_model_parallel_world_size()
# # Bypass the function if we are using only 1 GPU.
# if world_size == 1:
# return input_
# if dim < 0:
# # Convert negative dim to positive.
# dim += input_.dim()
# input_size = input_.size()
# # Allocate output tensor.
# output_tensor = torch.empty((world_size, ) + input_size,
# dtype=input_.dtype,
# device=input_.device)
# # All-gather.
# torch.distributed.all_gather_into_tensor(
# output_tensor, input_, group=get_tensor_model_parallel_group())
# # Reshape
# output_tensor = output_tensor.movedim(0, dim)
# output_tensor = output_tensor.reshape(input_size[:dim] +
# (world_size * input_size[dim], ) +
# input_size[dim + 1:])
# return output_tensor
# bigdl-llm change end
class BigDLSampler(nn.Module):
"""Samples the next tokens from the model's outputs.
This layer does the following:
1. Discard the hidden states that are not used for sampling (i.e., all
tokens except the final one in each prompt).
2. Compute the logits for the next tokens.
3. Apply presence and frequency penalties.
4. Apply temperature scaling.
5. Apply top-p and top-k truncation.
6. Sample the next tokens.
Here, each sequence group within the batch can have different sampling
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""
def __init__(self, vocab_size: int, device: Optional[str] = None) -> None:
super().__init__()
self.vocab_size = vocab_size
if device == 'xpu':
import intel_extension_for_pytorch as ipex
def forward(
self,
logits: torch.Tensor,
input_metadata: InputMetadata,
st_timestamp: Optional[float] = None,
) -> SamplerOutput:
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
if any(t != 1.0 for t in temperatures):
t = torch.tensor(temperatures,
dtype=logits.dtype,
device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
# Apply top-p and top-k truncation.
top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
if do_top_p or do_top_k:
logits = _apply_top_p_top_k(logits, top_ps, top_ks)
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata, st_timestamp)
# def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
# embedding_bias: Optional[torch.Tensor],
# vocab_size: int) -> torch.Tensor:
# # Get the logits for the next tokens.
# logits = torch.matmul(hidden_states, embedding.t())
# if embedding_bias is not None:
# logits += embedding_bias
# logits = tensor_model_parallel_all_gather(logits)
# # Remove paddings in vocab (if any).
# logits = logits[:, :vocab_size]
# return logits
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
last_token_indices = {t: [] for t in SamplingType}
start_idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
if i < input_metadata.num_prompts:
prompt_len = input_metadata.prompt_lens[i]
last_token_indices[sampling_type].append(start_idx + prompt_len -
1)
start_idx += prompt_len
else:
num_seqs = len(seq_ids)
last_token_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs
all_last_token_indices = []
for sampling_type in SamplingType:
all_last_token_indices.extend(last_token_indices[sampling_type])
all_last_token_indices = torch.tensor(all_last_token_indices,
dtype=torch.long,
device=hidden_states.device)
return hidden_states.index_select(0, all_last_token_indices)
def _get_penalties(
input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
for seq_group in input_metadata.seq_groups:
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
output_tokens: List[List[int]] = []
for seq_group in input_metadata.seq_groups:
seq_ids, _ = seq_group
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
output_tokens.append(seq_data.output_token_ids)
return output_tokens
def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
presence_penalties: List[float],
frequency_penalties: List[float],
) -> torch.Tensor:
num_seqs, vocab_size = logits.shape
for i in range(num_seqs):
if not output_tokens[i]:
continue
p = presence_penalties[i]
f = frequency_penalties[i]
if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
continue
break
else:
# Return early if all sequences have zero penalties.
return logits
max_output_len = max(len(tokens) for tokens in output_tokens)
padded_output_tokens = [
tokens + [vocab_size] * (max_output_len - len(tokens))
for tokens in output_tokens
]
output_tokens_tensor = torch.tensor(padded_output_tokens,
dtype=torch.long,
device=logits.device)
# Compute the bin counts for the output tokens.
# vocab_size + 1 for padding.
bin_counts = torch.zeros((num_seqs, vocab_size + 1),
dtype=torch.long,
device=logits.device)
bin_counts.scatter_add_(1, output_tokens_tensor,
torch.ones_like(output_tokens_tensor))
bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin.
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
return logits
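# Worked example (illustrative numbers): per the OpenAI definition above, a
# token already generated 3 times with frequency_penalty=0.5 and
# presence_penalty=0.2 has its logit reduced by 0.5 * 3 + 0.2 = 1.7.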
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for seq_group in input_metadata.seq_groups:
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS:
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
temperatures += [temperature] * len(seq_ids)
return temperatures
def _get_top_p_top_k(
input_metadata: InputMetadata,
vocab_size: int,
) -> Tuple[List[float], List[int]]:
top_ps: List[float] = []
top_ks: List[int] = []
for seq_group in input_metadata.seq_groups:
seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p
# k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation.
top_k = vocab_size if top_k == -1 else top_k
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
return top_ps, top_ks
def _apply_top_p_top_k(
logits: torch.Tensor,
top_ps: List[float],
top_ks: List[int],
) -> torch.Tensor:
p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
logits_sort[top_p_mask] = -float("inf")
# Apply top-k.
# Create a mask for the top-k elements.
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
logits_sort[top_k_mask] = -float("inf")
# Re-sort the probabilities.
logits = torch.gather(logits_sort,
dim=-1,
index=torch.argsort(logits_idx, dim=-1))
return logits
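# Worked example (illustrative): with top_k=2 and sorted logits
# [1.0, 0.5, 0.1], the third entry is masked to -inf, and the gather on
# argsort(logits_idx) scatters the surviving values back to their original
# token positions.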
def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: Optional[int],
) -> List[Dict[int, float]]:
num_seqs = logprobs.size(0)
if num_logprobs is None or num_logprobs == 0:
return [{} for _ in range(num_seqs)]
all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
num_logprobs,
dim=-1)
all_topk_logprobs = all_topk_logprobs.cpu()
all_topk_ids = all_topk_ids.cpu()
all_token_to_logprob = []
for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id.item()] = logprob.item()
all_token_to_logprob.append(token_to_logprob)
return all_token_to_logprob
def _build_sequence_outputs(
parent_ids: List[int],
next_token_ids: List[int],
selected_token_logprobs: List[float],
parent_seq_ids: List[int],
parent_logprobs: torch.Tensor,
num_output_logprobs: Optional[int],
st_timestamp: Optional[float],
) -> List[SequenceOutputs]:
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
seq_outputs: List[SequenceOutputs] = []
for parent_id, next_token_id, token_logprob in zip(
parent_ids, next_token_ids, selected_token_logprobs):
output_logprobs = next_logprobs[parent_id].copy()
output_logprobs[next_token_id] = token_logprob
ed_timestamp = time.perf_counter()
seq_outputs.append(
SequenceOutputs(parent_seq_ids[parent_id], next_token_id, ed_timestamp - st_timestamp,
output_logprobs))
return seq_outputs
def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
samples = torch.argmax(logprobs, dim=-1).cpu()
sample_idx = 0
results = []
for seq_group in selected_seq_groups:
seq_ids, _ = seq_group
num_parent_seqs = len(seq_ids)
parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples[sample_idx].item()]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _random_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
# Find the maximum best_of value of the prompt phase requests.
max_best_of = 1
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
if is_prompt:
seq_ids, sampling_params = seq_group
max_best_of = max(max_best_of, sampling_params.best_of)
random_samples = torch.multinomial(probs,
num_samples=max_best_of,
replacement=True).cpu()
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
parent_ids = [0] * sampling_params.best_of
next_token_ids = random_samples[
sample_idx, :sampling_params.best_of].tolist()
else:
# Generation phase.
parent_ids = list(range(num_parent_seqs))
next_token_ids = random_samples[sample_idx:sample_idx +
num_parent_seqs, 0].tolist()
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
def _beam_search_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
seq_data: Dict[int, SequenceData],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
# https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
# for details. See also HF reference:
# https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
#
# Note: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
if is_prompt:
# Prompt phase.
parent_ids = [0] * (2 * beam_width)
_, next_token_ids = torch.topk(seq_group_logprobs[0],
2 * beam_width)
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
cumulative_logprobs = [
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
vocab_size = seq_group_logprobs.size(-1)
parent_ids = [i // vocab_size for i in topk_ids]
next_token_ids = [i % vocab_size for i in topk_ids]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
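# Note (illustrative): in the generation phase with beam_width=2, topk over
# the flattened (num_beams * vocab_size) scores may pick two continuations
# of the same parent beam, in which case parent_ids repeats that beam's index.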
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
st_timestamp: Optional[float] = None,
) -> SamplerOutput:
categorized_seq_group_ids = {t: [] for t in SamplingType}
category_num_tokens = {t: 0 for t in SamplingType}
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
num_seqs = len(seq_ids)
category_num_tokens[sampling_type] += num_seqs
seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
category_start_idx = 0
for sampling_type in SamplingType:
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
num_tokens = category_num_tokens[sampling_type]
if num_tokens == 0:
continue
category_logprobs = logprobs[category_start_idx:category_start_idx +
num_tokens]
category_probs = probs[category_start_idx:category_start_idx +
num_tokens]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, category_logprobs)
elif sampling_type == SamplingType.RANDOM:
sample_results = _random_sample(seq_groups, is_prompts,
category_probs)
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
input_metadata.seq_data,
category_logprobs)
else:
invalidInputError(False,
f"Unsupported sampling type: {sampling_type}")
# Batched query for logprobs of selected token
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
sample_idx = 0
for seq_group_id, seq_group, sample_result in zip(
seq_group_ids, seq_groups, sample_results):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
batched_logprobs_query_token_indices.extend(next_token_ids)
sample_idx += num_parent_seqs
batched_logprobs_query_result = category_logprobs[[
batched_logprobs_query_seq_indices,
batched_logprobs_query_token_indices
]].tolist()
# Build the sequence outputs.
sample_idx = 0
result_idx = 0
for seq_group_id, seq_group, sample_result in zip(
seq_group_ids, seq_groups, sample_results):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_results = len(next_token_ids)
num_parent_seqs = len(seq_ids)
parent_logprobs = category_logprobs[sample_idx:sample_idx +
num_parent_seqs]
selected_token_logprobs = batched_logprobs_query_result[
result_idx:result_idx + num_results]
seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
selected_token_logprobs,
seq_ids, parent_logprobs,
sampling_params.logprobs,
st_timestamp)
seq_outputs_dict[seq_group_id] = seq_output
sample_idx += num_parent_seqs
result_idx += num_results
category_start_idx += num_tokens
return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]

View file

@ -0,0 +1,122 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/model_executor/model_loader.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for selecting and loading models."""
import contextlib
from typing import Type
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from bigdl.llm.vllm.config import ModelConfig
from bigdl.llm.vllm.model_executor.models.bigdl_llama import BigDLLlamaForCausalLM
from bigdl.llm.utils.common import invalidInputError
# bigdl-llm Intel specified code change
# bigdl-llm change start
# summary: Currently we only support the LLaMA model; users can add their own
# models by adding code in the ./models dir and then registering them here.
_MODEL_REGISTRY = {
# "AquilaModel": AquilaForCausalLM,
# "BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
# "BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
# "BloomForCausalLM": BloomForCausalLM,
# "FalconForCausalLM": FalconForCausalLM,
# "GPT2LMHeadModel": GPT2LMHeadModel,
# "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
# "GPTJForCausalLM": GPTJForCausalLM,
# "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
# "InternLMForCausalLM": InternLMForCausalLM,
"LlamaForCausalLM": BigDLLlamaForCausalLM,
# "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
# "MistralForCausalLM": MistralForCausalLM,
# "MPTForCausalLM": MPTForCausalLM,
# "OPTForCausalLM": OPTForCausalLM,
# "QWenLMHeadModel": QWenLMHeadModel,
# "RWForCausalLM": FalconForCausalLM,
}
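# For example, a hypothetical BigDLMistralForCausalLM placed under ./models
# could be enabled with a registry entry such as:
#   "MistralForCausalLM": BigDLMistralForCausalLM,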
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
# LlamaForCausalLM,
]
# bigdl-llm change end
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
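# Usage sketch: parameters created inside the context default to the given
# dtype, and the previous default is restored on exit.
#   with _set_default_torch_dtype(torch.float16):
#       layer = torch.nn.Linear(4, 4)  # float16 parameters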
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _MODEL_REGISTRY:
return _MODEL_REGISTRY[arch]
invalidInputError(
False,
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
# Get the quantization config.
quant_config = None
if model_config.quantization is not None:
invalidInputError(f"Quantization is not supported for {model_class}.")
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
model = model_class(model_config.hf_config, quant_config)
else:
model = model_class(model_config.hf_config, device=model_config.device,
max_model_len=model_config.max_model_len)
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model, model_config.download_dir,
model_config.load_format, model_config.revision)
# bigdl-llm Intel specified code change
# bigdl-llm change start
# summary: Only use cuda when device is gpu.
if model_config.device == 'gpu':
model = model.cuda()
# bigdl-llm change end
return model.eval()

View file

@ -0,0 +1,184 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch import nn
from transformers import AutoTokenizer, PreTrainedTokenizerBase, LlamaConfig
from typing import Optional, Tuple, List, Type, Dict
from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
from bigdl.llm.vllm.model_executor.layers.bigdl_sampler import BigDLSampler
from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausalLM
import math
import time
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]:
return x + [padding_id] * (max_len - len(x))
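# e.g. _pad_to_max([5, 6], 4) -> [5, 6, 0, 0] (right-padded with padding_id)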
class BigDLLlamaForCausalLM(BigDLModelForCausalLM):
def __init__(
self,
config: LlamaConfig,
device: Optional[str] = None,
max_model_len: Optional[int] = None,
):
super().__init__(config, device, max_model_len)
self.config = config
# TODO(gc): later change this to a switch?
if True:
from bigdl.llm.transformers import AutoModelForCausalLM
from bigdl.llm import optimize_model
# low_bit = 'sym_int4'
if device == 'cpu':
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
low_cpu_mem_usage=True,
trust_remote_code=True,
use_cache=True,
)
self.model = optimize_model(model)
self.sampler = BigDLSampler(config.vocab_size, device)
elif device == 'xpu':
try:
import intel_extension_for_pytorch as ipex
except ImportError:
print("Intel Extension for PyTorch is not installed, "
"but is required for xpu inference.")
low_bit = 'sym_int4'
model = AutoModelForCausalLM.from_pretrained(
config._name_or_path,
load_in_low_bit=low_bit,
trust_remote_code=True,
use_cache=True,
)
self.model = model.to('xpu')
self.sampler = BigDLSampler(config.vocab_size, device).to('xpu')
if device is None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
self.dtype = self.model.dtype
self.kv_cache_size = [0]
self.last_seq_ids = []
self.tmp_kv_cache = None
self.pad_token_id = config.pad_token_id
self.max_seq_limit = max_model_len
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Optional = None,
input_metadata: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
kv_cache_size_0 = self.model.config.num_hidden_layers
kv_cache_size_1 = 2
seq_len = len(seq_group_meta_data_lists)
bigdl_input_ids = []
bigdl_position_ids = []
bigdl_attention_mask = []
cur_seq_ids = []
bigdl_sampling_params = {}
max_context_len = 0
all_decoding = True
for seq_group_meta_data in seq_group_meta_data_lists:
req_id = seq_group_meta_data.request_id
all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
cur_seq_ids.append(seq_id)
seq_data = seq_group_meta_data.seq_data[seq_id]
cur_seq_input_ids = seq_data.get_token_ids()
context_len = seq_data.get_len()
if seq_group_meta_data.is_prompt:
bigdl_input_ids.append(cur_seq_input_ids)
max_context_len = max(max_context_len, context_len)
else:
bigdl_input_ids.append([cur_seq_input_ids[-1]])
bigdl_sampling_params[seq_id] = seq_group_meta_data.sampling_params
if all_decoding:
bigdl_kv_cache = self.prepare_kv_cache(cur_seq_ids, seq_group_meta_data_lists,
kv_cache, kv_cache_size_0, kv_cache_size_1)
else:
bigdl_input_ids = [
_pad_to_max(input_ids, max_context_len, self.pad_token_id)
for input_ids in bigdl_input_ids
]
if all_decoding:
cur_seq_len = bigdl_kv_cache[0][0].size(2)
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
cur_pos = seq_data.get_len()
bigdl_position_ids.append([cur_pos - 1])
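# The decoding attention mask covers the left-padded KV cache plus the new
# token: zeros over the padding, ones over the cur_pos real positions
# (total length cur_seq_len + 1).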
cur_attention_mask = [0] * (cur_seq_len - cur_pos + 1) + [1] * (cur_pos)
bigdl_attention_mask.append(cur_attention_mask)
bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
if all_decoding:
bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
bigdl_attention_mask = torch.tensor(bigdl_attention_mask, device=self.device)
kwargs = {
"input_ids": bigdl_input_ids,
"position_ids": bigdl_position_ids,
"attention_mask": bigdl_attention_mask,
"past_key_values": bigdl_kv_cache,
"use_cache": True,
# "return_dict": True,
}
else:
kwargs = {
"input_ids": bigdl_input_ids,
# "position_ids": bigdl_position_ids,
"past_key_values": None,
"use_cache": True,
# "return_dict": True,
}
st_timestamp = time.perf_counter()
outputs = self.model.forward(**kwargs)
self.last_seq_ids = cur_seq_ids[:]
self.tmp_kv_cache = outputs.past_key_values
logits = outputs.logits[:, -1, :]
bigdl_output = self.sampler(logits, input_metadata, st_timestamp)
self.update_kv_cache(cur_seq_ids, outputs.past_key_values,
kv_cache, kv_cache_size_0, kv_cache_size_1)
return bigdl_output

View file

@ -0,0 +1,140 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch import nn
from typing import Optional, Tuple, List, Type, Dict
from transformers import LlamaConfig
from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata
def _pad_kv_cache_view(t: torch.Tensor, target_len: int,
device: torch.device, pos: int = 2) -> torch.Tensor:
cur_size = list(t.size())
if cur_size[pos] < target_len:
tmp_size = cur_size[:]
tmp_size[pos] = target_len - cur_size[pos]
zeros = torch.zeros(tmp_size, device=device)
padded_view = torch.cat((zeros, t), dim=pos)
return padded_view
if cur_size[pos] > target_len:
padded_view = t.narrow(pos, cur_size[pos] - target_len, target_len)
return padded_view
return t
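# e.g. for t of shape (1, heads, 3, head_dim): a target length of 5 prepends
# two zero timesteps along dim 2 (left padding), while a target length of 2
# keeps only the last two timesteps.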
class BigDLModelForCausalLM(nn.Module):
def __init__(
self,
config,
device: Optional[str] = None,
max_model_len: Optional[int] = None,
):
super().__init__()
self.config = config
if device is None:
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = torch.device(device)
self.max_seq_limit = max_model_len
self.last_kv_cache = None
self.last_seq_ids = None
# This is an implementation for models whose KV cache shape is (batch_size,
# num_heads, sequence_length, embed_size_per_head).
def prepare_kv_cache(
self,
cur_seq_ids: List[int],
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Dict,
kv_cache_size_0: int,
kv_cache_size_1: int,
):
max_seq_limit = self.max_seq_limit
if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids:
if self.last_kv_cache[0][0].size(2) < max_seq_limit:
bigdl_kv_cache = self.tmp_kv_cache
else:
bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2)
- max_seq_limit, max_seq_limit)
for tmp in tmp_list] for tmp_list in self.last_kv_cache]
else:
bigdl_kv_cache = []
for i in range(kv_cache_size_0):
cur_list = []
for j in range(kv_cache_size_1):
cur_view = None
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
seq_data = seq_group_meta_data.seq_data[seq_id]
view_size = [1] + list(kv_cache[seq_id][i][j].shape)
if cur_view is None:
cur_view = kv_cache[seq_id][i][j].view(view_size)
else:
if cur_view.size(2) != view_size[2]:
max_len = max(cur_view.size(2), view_size[2])
cur_view = _pad_kv_cache_view(cur_view, max_len, self.device)
tmp_view = _pad_kv_cache_view(
kv_cache[seq_id][i][j].view(view_size),
max_len, self.device)
cur_view = torch.cat((cur_view, tmp_view), dim=0)
else:
cur_view = torch.cat(
(cur_view, kv_cache[seq_id][i][j].view(view_size)), dim=0)
if cur_view.size(2) > max_seq_limit:
cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device)
cur_list.append(cur_view)
bigdl_kv_cache.append(cur_list)
return bigdl_kv_cache
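# Illustrative: for two sequences with cached lengths 3 and 5, each per-layer
# view is left-padded to length 5 before being concatenated along the batch
# dimension, then truncated to the last max_seq_limit positions if too long.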
# This is an implementation for models whose KV cache shape is (batch_size,
# num_heads, sequence_length, embed_size_per_head).
def update_kv_cache(
self,
cur_seq_ids: List[int],
past_key_values: List[List[torch.Tensor]],
kv_cache: Dict,
kv_cache_size_0: int,
kv_cache_size_1: int,
) -> None:
index = 0
for seq_id in cur_seq_ids:
if kv_cache.get(seq_id) is None:
kv_cache[seq_id] = [[[] for _ in range(kv_cache_size_1)]
for _ in range(kv_cache_size_0)]
for i in range(kv_cache_size_0):
for j in range(kv_cache_size_1):
kv_cache[seq_id][i][j] = past_key_values[i][j][index]
index = index + 1
def forward(
self,
seq_group_meta_data_lists: List[SequenceGroupMetadata],
kv_cache: Optional = None,
input_metadata: Optional = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
pass
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
pass

View file

@ -0,0 +1,45 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are ported from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/model_executor/utils.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for model executor."""
import random
import numpy as np
import torch
def set_random_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)

View file

@ -0,0 +1,156 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/outputs.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# bigdl-llm Intel specified code change
from typing import Dict, List, Optional
from bigdl.llm.vllm.sequence import SequenceGroup, SequenceStatus
class CompletionOutput:
"""The output data of one completion output of a request.
Args:
index: The index of the output in the request.
text: The generated output text.
token_ids: The token IDs of the generated output text.
cumulative_logprob: The cumulative log probability of the generated
output text.
logprobs: The log probabilities of the top probability words at each
position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
        output_token_latency: The per-token generation latency, in seconds
            (BigDL-LLM addition).
    """
def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprob: float,
logprobs: Optional[List[Dict[int, float]]],
finish_reason: Optional[str] = None,
output_token_latency: Optional[List[float]] = None,
) -> None:
# bigdl-llm change start
    # summary: add token-latency recording argument
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.output_token_latency = output_token_latency
# bigdl-llm change end
def finished(self) -> bool:
return self.finish_reason is not None
def __repr__(self) -> str:
return (f"CompletionOutput(index={self.index}, "
f"text={self.text!r}, "
f"token_ids={self.token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"logprobs={self.logprobs}, "
f"finish_reason={self.finish_reason})"
f"output_token_latency={self.output_token_latency}, ")
class RequestOutput:
"""The output data of a request to the LLM.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
prompt_token_ids: The token IDs of the prompt.
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
"""
def __init__(
self,
request_id: str,
prompt: str,
prompt_token_ids: List[int],
outputs: List[CompletionOutput],
finished: bool,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.outputs = outputs
self.finished = finished
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
# Get the top-n sequences.
n = seq_group.sampling_params.n
seqs = seq_group.get_seqs()
if seq_group.sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
seq_group.sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
# Create the outputs.
outputs: List[CompletionOutput] = []
for seq in top_n_seqs:
logprobs = seq.output_logprobs
if seq_group.sampling_params.logprobs is None:
# NOTE: We need to take care of this case because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
logprobs = {}
            finished_reason = SequenceStatus.get_finished_reason(seq.status)
output = CompletionOutput(seqs.index(seq), seq.output_text,
seq.get_output_token_ids(),
seq.get_cumulative_logprob(), logprobs,
                                      finished_reason,
seq.get_output_token_latency())
outputs.append(output)
# Every sequence in the sequence group should have the same prompt.
prompt = top_n_seqs[0].prompt
prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
finished = seq_group.is_finished()
return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
finished)
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs}, "
f"finished={self.finished})")

View file

@ -0,0 +1,237 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/sampling_params.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
from bigdl.llm.utils.common import invalidInputError
_SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
BEAM = 2
class SamplingParams:
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
API (https://platform.openai.com/docs/api-reference/completions/create).
In addition, we support beam search, which is not supported by OpenAI.
Args:
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. This is treated as
the beam width when `use_beam_search` is True. By default, `best_of`
is set to `n`.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
tokens.
frequency_penalty: Float that penalizes new tokens based on their
frequency in the generated text so far. Values > 0 encourage the
model to use new tokens, while values < 0 encourage the model to
repeat tokens.
temperature: Float that controls the randomness of the sampling. Lower
values make the model more deterministic, while higher values make
the model more random. Zero means greedy sampling.
top_p: Float that controls the cumulative probability of the top tokens
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
            soon as there are `best_of` complete candidates; `False`, where a
            heuristic is applied and the generation stops when it is very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are
generated. The returned output will contain the stop tokens unless
            the stop tokens are special tokens.
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
logprobs: Number of log probabilities to return per output token.
skip_special_tokens: Whether to skip special tokens in the output.
"""
def __init__(
self,
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
use_beam_search: bool = False,
length_penalty: float = 1.0,
        early_stopping: Union[bool, str] = False,
        stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
ignore_eos: bool = False,
max_tokens: int = 16,
logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
if stop is None:
self.stop = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = list(stop)
if stop_token_ids is None:
self.stop_token_ids = []
else:
self.stop_token_ids = list(stop_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.logprobs = logprobs
self.skip_special_tokens = skip_special_tokens
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
else:
self._verify_non_beam_search()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self._verify_greedy_sampling()
    def _verify_args(self) -> None:
        if self.n < 1:
            invalidInputError(False, f"n must be at least 1, got {self.n}.")
        if self.best_of < self.n:
            invalidInputError(False,
                              "best_of must be greater than or equal to n, "
                              f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            invalidInputError(False,
                              "presence_penalty must be in [-2, 2], got "
                              f"{self.presence_penalty}.")
        if not -2.0 <= self.frequency_penalty <= 2.0:
            invalidInputError(False,
                              "frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
        if self.temperature < 0.0:
            invalidInputError(
                False,
                f"temperature must be non-negative, got {self.temperature}.")
        if not 0.0 < self.top_p <= 1.0:
            invalidInputError(False,
                              f"top_p must be in (0, 1], got {self.top_p}.")
        if self.top_k < -1 or self.top_k == 0:
            invalidInputError(False,
                              "top_k must be -1 (disable), or at least 1, "
                              f"got {self.top_k}.")
        if self.max_tokens < 1:
            invalidInputError(
                False,
                f"max_tokens must be at least 1, got {self.max_tokens}.")
        if self.logprobs is not None and self.logprobs < 0:
            invalidInputError(
                False,
                f"logprobs must be non-negative, got {self.logprobs}.")

    def _verify_beam_search(self) -> None:
        if self.best_of == 1:
            invalidInputError(False,
                              "best_of must be greater than 1 when using beam "
                              f"search. Got {self.best_of}.")
        if self.temperature > _SAMPLING_EPS:
            invalidInputError(False,
                              "temperature must be 0 when using beam search.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            invalidInputError(False, "top_p must be 1 when using beam search.")
        if self.top_k != -1:
            invalidInputError(False, "top_k must be -1 when using beam search.")
        if self.early_stopping not in [True, False, "never"]:
            invalidInputError(False,
                              "early_stopping must be True, False, or 'never', "
                              f"got {self.early_stopping}.")

    def _verify_non_beam_search(self) -> None:
        if self.early_stopping is not False:
            invalidInputError(False,
                              "early_stopping is not effective and must be "
                              "False when not using beam search.")
        if (self.length_penalty < 1.0 - _SAMPLING_EPS
                or self.length_penalty > 1.0 + _SAMPLING_EPS):
            invalidInputError(
                False,
                "length_penalty is not effective and must be the "
                "default value of 1.0 when not using beam search.")

    def _verify_greedy_sampling(self) -> None:
        if self.best_of > 1:
            invalidInputError(False,
                              "best_of must be 1 when using greedy sampling. "
                              f"Got {self.best_of}.")
        if self.top_p < 1.0 - _SAMPLING_EPS:
            invalidInputError(False,
                              "top_p must be 1 when using greedy sampling.")
        if self.top_k != -1:
            invalidInputError(False,
                              "top_k must be -1 when using greedy sampling.")
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
return SamplingType.RANDOM
def __repr__(self) -> str:
return (f"SamplingParams(n={self.n}, "
f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"logprobs={self.logprobs}, "
f"skip_special_tokens={self.skip_special_tokens})")

View file

@ -0,0 +1,388 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/sequence.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sequence and its related classes."""
import copy
import enum
import time
from typing import Dict, List, Optional, Union
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.utils.common import invalidInputError
class SequenceStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto()
RUNNING = enum.auto()
SWAPPED = enum.auto()
FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
return status in [
SequenceStatus.FINISHED_STOPPED,
SequenceStatus.FINISHED_LENGTH_CAPPED,
SequenceStatus.FINISHED_ABORTED,
SequenceStatus.FINISHED_IGNORED,
]
@staticmethod
def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
if status == SequenceStatus.FINISHED_STOPPED:
finish_reason = "stop"
elif status == SequenceStatus.FINISHED_LENGTH_CAPPED:
finish_reason = "length"
elif status == SequenceStatus.FINISHED_ABORTED:
finish_reason = "abort"
elif status == SequenceStatus.FINISHED_IGNORED:
# The ignored sequences are the sequences whose prompt lengths
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
finish_reason = "length"
else:
finish_reason = None
return finish_reason
class SequenceData:
"""Data associated with a sequence.
Args:
prompt_token_ids: The token IDs of the prompt.
Attributes:
prompt_token_ids: The token IDs of the prompt.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
def __init__(
self,
prompt_token_ids: List[int],
) -> None:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprob = 0.0
self.created_timestamp = time.perf_counter()
self.updated_timestamp = self.created_timestamp
self.last_token_latency = 0.0
def append_token_id(self, token_id: int, logprob: float) -> None:
self.output_token_ids.append(token_id)
self.cumulative_logprob += logprob
cur_timestamp = time.perf_counter()
self.last_token_latency = cur_timestamp - self.updated_timestamp
self.updated_timestamp = cur_timestamp
def get_len(self) -> int:
return len(self.output_token_ids) + len(self.prompt_token_ids)
def get_prompt_len(self) -> int:
return len(self.prompt_token_ids)
def get_output_len(self) -> int:
return len(self.output_token_ids)
def get_token_ids(self) -> List[int]:
return self.prompt_token_ids + self.output_token_ids
def get_last_token_id(self) -> int:
if not self.output_token_ids:
return self.prompt_token_ids[-1]
return self.output_token_ids[-1]
def get_last_token_latency(self) -> float:
return self.last_token_latency
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob})")
class Sequence:
"""Stores the data, status, and block information of a sequence.
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
"""
def __init__(
self,
seq_id: int,
prompt: str,
prompt_token_ids: List[int],
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.data = SequenceData(prompt_token_ids)
self.output_logprobs: List[Dict[int, float]] = []
self.output_text = ""
self.output_token_latency: List[float] = []
self.status = SequenceStatus.WAITING
# Used for incremental detokenization
self.prefix_offset = 0
self.read_offset = 0
# Input + output tokens
self.tokens: Optional[List[str]] = None
def append_token_id(
self,
token_id: int,
logprobs: Dict[int, float],
latency: Optional[float] = None,
) -> None:
invalidInputError(token_id in logprobs, "token id should be in logprobs")
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id])
        if latency is not None:
self.output_token_latency.append(latency)
def get_len(self) -> int:
return self.data.get_len()
def get_prompt_len(self) -> int:
return self.data.get_prompt_len()
def get_output_len(self) -> int:
return self.data.get_output_len()
def get_token_ids(self) -> List[int]:
return self.data.get_token_ids()
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()
def get_output_token_ids(self) -> List[int]:
return self.data.output_token_ids
def get_output_token_latency(self) -> List[float]:
return self.output_token_latency
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 0.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# Note: HF implementation does not count the EOS token
# towards the length, we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
def fork(self, new_seq_id: int) -> "Sequence":
new_seq = copy.deepcopy(self)
new_seq.seq_id = new_seq_id
return new_seq
def __repr__(self) -> str:
return (f"Sequence(seq_id={self.seq_id}, "
f"status={self.status.name})")
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
Args:
request_id: The ID of the request.
seqs: The list of sequences.
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
"""
def __init__(
self,
request_id: str,
seqs: List[Sequence],
sampling_params: SamplingParams,
arrival_time: float,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.arrival_time = arrival_time
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
return self.sampling_params.best_of
else:
if self.sampling_params.best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
                # and only has one sequence running. However, in the
# generation stage, we will have `best_of` sequences running.
return self.sampling_params.best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
if status is None:
return list(self.seqs_dict.values())
else:
return [
seq for seq in self.seqs_dict.values() if seq.status == status
]
def get_unfinished_seqs(self) -> List[Sequence]:
return [
seq for seq in self.seqs_dict.values() if not seq.is_finished()
]
def get_finished_seqs(self) -> List[Sequence]:
return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int:
return len(self.get_unfinished_seqs())
def num_finished_seqs(self) -> int:
return len(self.get_finished_seqs())
def find(self, seq_id: int) -> Sequence:
if seq_id not in self.seqs_dict:
invalidInputError(False, f"Sequence {seq_id} not found.")
return self.seqs_dict[seq_id]
def add(self, seq: Sequence) -> None:
if seq.seq_id in self.seqs_dict:
invalidInputError(False, f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq
def remove(self, seq_id: int) -> None:
if seq_id not in self.seqs_dict:
invalidInputError(False, f"Sequence {seq_id} not found.")
del self.seqs_dict[seq_id]
def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.get_seqs())
def __repr__(self) -> str:
return (f"SequenceGroup(request_id={self.request_id}, "
f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs_dict)})")
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
Args:
request_id: The ID of the request.
is_prompt: Whether the request is at prompt stage.
seq_data: The sequence data. (Seq id -> sequence data)
sampling_params: The sampling parameters used to generate the outputs.
"""
def __init__(
self,
request_id: str,
is_prompt: bool,
seq_data: Dict[int, SequenceData],
sampling_params: SamplingParams,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
self.seq_data = seq_data
self.sampling_params = sampling_params
class SequenceOutputs:
"""The model output associated with a sequence.
Args:
parent_seq_id: The ID of the parent sequence (for forking in beam
search).
output_token: The output token ID.
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
def __init__(
self,
parent_seq_id: int,
output_token: int,
latency: float,
logprobs: Dict[int, float],
) -> None:
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
self.latency = latency
def __repr__(self) -> str:
return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool:
if not isinstance(other, SequenceOutputs):
invalidInputError(False, "Not implemented")
return (self.parent_seq_id == other.parent_seq_id
and self.output_token == other.output_token
and self.logprobs == other.logprobs)
# For each sequence group, we generate a list of SequenceOutputs object,
# each of which contains one possible candidate for the next token.
SamplerOutput = List[List[SequenceOutputs]]
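
A minimal sketch of how these classes fit together during decoding, including the BigDL-added per-token latency recording. This is illustrative; the scheduler and engine normally drive this:

```python
# Illustrative only: a Sequence inside a SequenceGroup.
import time
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.vllm.sequence import Sequence, SequenceGroup, SequenceStatus

seq = Sequence(seq_id=0, prompt="Hello", prompt_token_ids=[15043])
group = SequenceGroup(request_id="req-0", seqs=[seq],
                      sampling_params=SamplingParams(max_tokens=8),
                      arrival_time=time.perf_counter())

seq.status = SequenceStatus.RUNNING
seq.append_token_id(token_id=3186, logprobs={3186: -0.2}, latency=0.02)
print(seq.get_output_token_ids())      # [3186]
print(seq.get_output_token_latency())  # [0.02]

seq.status = SequenceStatus.FINISHED_STOPPED
print(group.is_finished())             # True: every sequence has finished
```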

View file

@ -0,0 +1,14 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,201 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/transformers_utils/tokenizer.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
from transformers import (AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer_name: str,
*args,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
"""Gets a tokenizer for the given model name via Huggingface."""
if tokenizer_mode == "slow":
if kwargs.get("use_fast", False):
invalidInputError("Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
and tokenizer_name != _FAST_LLAMA_TOKENIZER):
logger.info(
"For some LLaMA V1 models, initializing the fast tokenizer may "
"take a long time. To reduce the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer.")
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
*args,
trust_remote_code=trust_remote_code,
tokenizer_revision=tokenizer_revision,
**kwargs)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
err_msg = (
"Failed to load the tokenizer. If you are using a LLaMA V1 model "
f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
"original tokenizer.")
        invalidInputError(False, err_msg)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
            invalidInputError(False, err_msg)
else:
            invalidInputError(False, str(e))
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead.")
return tokenizer
def _convert_tokens_to_string_with_added_encoders(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
output_tokens: List[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
) -> str:
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
# NOTE(woosuk): The following code is slow because it runs a for loop over
# the output_tokens. In Python, running a for loop over a list can be slow
# even when the loop body is very simple.
sub_texts = []
current_sub_text = []
all_special_tokens = set(tokenizer.all_special_tokens)
for token in output_tokens:
if skip_special_tokens and token in all_special_tokens:
continue
if token in tokenizer.get_added_vocab():
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
sub_texts.append(sub_text)
if spaces_between_special_tokens:
return " ".join(sub_texts)
else:
return "".join(sub_texts)
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
all_input_ids: List[int],
prev_tokens: Optional[List[str]],
prefix_offset: int = 0,
read_offset: int = 0,
skip_special_tokens: bool = False,
spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
new_token_id = all_input_ids[-1]
# This is the first iteration for this sequence
if prev_tokens is None:
new_tokens = tokenizer.convert_ids_to_tokens(
all_input_ids, skip_special_tokens=skip_special_tokens)
output_tokens = new_tokens
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
# Subtract 1 extra to account for the generated token.
prefix_offset = max(len(output_tokens) - 6, 0)
# If the first new token is a special token, we can't skip 1 extra token
if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
read_offset = max(len(output_tokens), 0)
else:
read_offset = max(len(output_tokens) - 1, 0)
else:
# Put new_token_id in a list so skip_special_tokens is respected
new_tokens = tokenizer.convert_ids_to_tokens(
[new_token_id], skip_special_tokens=skip_special_tokens)
output_tokens = prev_tokens + new_tokens
# The prefix text is necessary only to defeat cleanup algorithms in
# the decode which decide to add a space or not depending on the
# surrounding ids.
if tokenizer.is_fast or not tokenizer.get_added_vocab():
prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:read_offset])
new_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_offset:])
else:
prefix_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
new_text = _convert_tokens_to_string_with_added_encoders(
tokenizer,
output_tokens[prefix_offset:],
skip_special_tokens=skip_special_tokens,
spaces_between_special_tokens=spaces_between_special_tokens,
)
    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
# If it's in the middle, it's probably a real invalid id generated
# by the model
new_text = new_text[len(prefix_text):]
return new_tokens, new_text, read_offset, len(output_tokens)
else:
return new_tokens, "", prefix_offset, read_offset

View file

@ -0,0 +1,60 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/utils.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
import socket
from typing import List, Optional, Tuple, Union
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from bigdl.llm.vllm.logger import init_logger
from bigdl.llm.utils.common import invalidInputError
logger = init_logger(__name__)
class Counter:
def __init__(self, start: int = 0) -> None:
self.counter = start
def __next__(self) -> int:
i = self.counter
self.counter += 1
return i
def reset(self) -> None:
self.counter = 0
def random_uuid() -> str:
return str(uuid.uuid4().hex)
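
These helpers are used by the engine to assign sequence and request IDs; a quick illustration follows. The import path is an assumption based on the ported vLLM layout (`vllm/utils.py`):

```python
# Illustrative only.
from bigdl.llm.vllm.utils import Counter, random_uuid

seq_counter = Counter()
print(next(seq_counter))  # 0
print(next(seq_counter))  # 1
seq_counter.reset()
print(next(seq_counter))  # 0 again

print(random_uuid())      # e.g. '6f1f0a33c3a744cdb7a7e5f0b1e9c2af'
```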

View file

@ -0,0 +1,338 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/vllm/worker/worker.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A GPU worker class."""
import os
from typing import Dict, List, Tuple, Optional
import torch
import torch.distributed
import warnings
import numpy as np
import random
from bigdl.llm.vllm.config import ModelConfig, SchedulerConfig
from bigdl.llm.vllm.model_executor.model_loader import get_model
from bigdl.llm.vllm.model_executor.input_metadata import InputMetadata
from bigdl.llm.vllm.sampling_params import SamplingParams
from bigdl.llm.vllm.sequence import SequenceData, SamplerOutput, SequenceGroupMetadata
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.vllm.model_executor.utils import set_random_seed
class Worker:
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
maintaining the KV cache and executing the model on the GPU. In case of
distributed inference, each worker is assigned a partition of the model.
"""
    # bigdl-llm Intel-specific code change
    # bigdl-llm change start
    # summary: Remove config for parallel and cache engine.
    # Add a kv_cache dict and methods to maintain it.
def __init__(
self,
model_config: ModelConfig,
# parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
rank: Optional[int] = None,
# distributed_init_method: Optional[str] = None,
) -> None:
self.model_config = model_config
# self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.rank = rank
# self.distributed_init_method = distributed_init_method
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.sliding_window = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
self.kv_cache = dict()
def clean_finished_seqs(self, finished_seqs: List[int]):
"""
        Remove the KV cache entries of finished sequences from self.kv_cache.
"""
for seq_id in finished_seqs:
            if seq_id not in self.kv_cache:
                warnings.warn(f"Sequence {seq_id} is not in worker's KVCache "
                              "during cleanup")
                continue
del self.kv_cache[seq_id]
def init_model(self):
if self.model_config.device == 'gpu':
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
# Env vars will be set by Ray.
self.rank = self.rank if self.rank is not None else int(
os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
invalidInputError(False, "Invalid or unspecified rank.")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
# Initialize the distributed environment.
# Co(gc): Consider this later
# _init_distributed_environment(self.parallel_config, self.rank,
# self.distributed_init_method)
# Initialize the model.
set_random_seed(self.model_config.seed)
self.model = get_model(self.model_config)
def _prepare_inputs(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
# Add prompt tokens.
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
if not seq_group_metadata.is_prompt:
continue
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
# Use any sequence in the group.
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(range(len(prompt_tokens)))
# if seq_group_metadata.block_tables is None:
# # During memory profiling, the block tables are not initialized
# # yet. In this case, we just use a dummy slot mapping.
# slot_mapping.extend([0] * prompt_len)
# continue
# # Compute the slot mapping.
# block_table = seq_group_metadata.block_tables[seq_id]
# for i in range(prompt_len):
# block_number = block_table[i // self.block_size]
# block_offset = i % self.block_size
# slot = block_number * self.block_size + block_offset
# slot_mapping.append(slot)
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
context_lens: List[int] = []
# generation_block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
if seq_group_metadata.is_prompt:
continue
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
context_len = seq_data.get_len()
position = context_len - 1
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
input_positions.append(position)
# block_table = seq_group_metadata.block_tables[seq_id]
max_context_len = max(max_context_len, context_len)
# max_num_blocks_per_seq = max(max_num_blocks_per_seq,
# len(block_table))
context_lens.append(context_len)
# block_number = block_table[position // self.block_size]
# block_offset = position % self.block_size
# slot = block_number * self.block_size + block_offset
# slot_mapping.append(slot)
                # The block-table based sliding-window trimming relies on the
                # (removed) block manager, so it stays commented out like the
                # rest of the block-table logic above.
                # if self.sliding_window is not None:
                #     sliding_window_blocks = (self.sliding_window //
                #                              self.block_size)
                #     block_table = block_table[-sliding_window_blocks:]
                # generation_block_tables.append(block_table)
# Optimization: Pad the input length to be a multiple of 8.
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
# Convert to tensors.
tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
# device="cuda"
)
positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
# device="cuda"
)
# slot_mapping_tensor = torch.tensor(slot_mapping,
# dtype=torch.int,
# device="cuda")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
# device="cuda"
)
# padded_block_tables = [
# _pad_to_max(block_table, max_num_blocks_per_seq)
# for block_table in generation_block_tables
# ]
# block_tables_tensor = torch.tensor(padded_block_tables,
# dtype=torch.int,
# device="cuda")
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
# slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
# block_tables=block_tables_tensor,
sliding_window=self.sliding_window,
)
return tokens_tensor, positions_tensor, input_metadata
# TODO(gc): we may want to delete unused parameters
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
finished_seqs: List[int],
) -> SamplerOutput:
# Issue cache operations.
# issued_cache_op = False
# if blocks_to_swap_in:
# self.cache_engine.swap_in(blocks_to_swap_in)
# issued_cache_op = True
# if blocks_to_swap_out:
# self.cache_engine.swap_out(blocks_to_swap_out)
# issued_cache_op = True
# if blocks_to_copy:
# self.cache_engine.copy(blocks_to_copy)
# issued_cache_op = True
# if issued_cache_op:
# cache_events = self.cache_events
# else:
# cache_events = None
if finished_seqs:
self.clean_finished_seqs(finished_seqs)
cache_events = None
# If there is no input, we don't need to execute the model.
if not seq_group_metadata_list:
if cache_events is not None:
for event in cache_events:
event.wait()
return {}
# pdb.set_trace()
        # TODO(gc): use environment/global variable to check
if True:
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seq_group_metadata_list)
output = self.model(
seq_group_meta_data_lists=seq_group_metadata_list,
kv_cache=self.kv_cache, input_metadata=input_metadata)
return output
else:
# Prepare input tensors.
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seq_group_metadata_list)
# Execute the model.
output = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=None,
input_metadata=input_metadata,
cache_events=cache_events,
)
return output
# bigdl-llm change end
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
return x + [0] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x))
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
compute_capability = torch.cuda.get_device_capability()
if compute_capability[0] < 8:
gpu_name = torch.cuda.get_device_name()
invalidInputError(
False,
"Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")