Upgrade to vllm 0.6.2 (#12338)

* Initial updates for vllm 0.6.2

* fix

* Change Dockerfile to support v062

* Fix

* fix examples

* Fix

* done

* fix

* Update engine.py

* Fix Dockerfile to original path

* fix

* add option

* fix

* fix

* fix

* fix

---------

Co-authored-by: xiangyuT <xiangyu.tian@intel.com>
Guancheng Fu 2024-11-12 20:35:34 +08:00 committed by GitHub
parent 4376fdee62
commit 0ee54fc55f
13 changed files with 617 additions and 520 deletions

View file

@ -5,6 +5,8 @@ ARG https_proxy
ENV TZ=Asia/Shanghai
ENV PYTHONUNBUFFERED=1
# To prevent RPC_TIMEOUT ERROR for the first request
ENV VLLM_RPC_TIMEOUT=100000
# Disable pip's cache behavior
@ -42,6 +44,14 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
pip install --upgrade colorama && \
# Download all-in-one benchmark and examples
git clone https://github.com/intel-analytics/ipex-llm && \
# To build ipex-llm from source (this PR's branch), uncomment the following lines:
# cd ipex-llm && \
# git fetch origin pull/12338/head:local_pr && \
# git checkout local_pr && \
# pip uninstall -y ipex-llm && \
# cd python/llm && \
# python setup.py install && \
# cd ../../../ && \
cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \
cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \
# Install vllm dependencies
@ -76,13 +86,16 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
rm -rf /tmp/neo && \
mkdir -p /llm && \
cd /llm && \
git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
cd /llm/vllm && \
pip install -r /llm/vllm/requirements-xpu.txt && \
VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install setuptools-scm && \
pip install --upgrade cmake && \
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
# pip install -r /llm/vllm/requirements-xpu.txt && \
# VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install mpi4py fastapi uvicorn openai && \
pip install gradio==4.43.0 && \
pip install transformers==4.44.2 && \
# pip install transformers==4.44.2 && \
# patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \
pip install ray && \
patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch
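After the image is built, a quick import check confirms the new stack is in place. A minimal sketch, assuming the container was built from this Dockerfile with the 0.6.2 vLLM branch and IPEX-LLM installed:

```python
# Sanity check inside the built container: both the vLLM 0.6.2 build and the
# IPEX-LLM wrapper classes touched in this PR should be importable.
import vllm
from ipex_llm.vllm.xpu.engine import IPEXLLMClass, IPEXLLMAsyncLLMEngine

print("vLLM version:", vllm.__version__)  # expect a 0.6.x build from the 0.6.2 branch
print("IPEX-LLM wrappers:", IPEXLLMClass.__name__, IPEXLLMAsyncLLMEngine.__name__)
```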

View file

@ -1,14 +1,22 @@
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import random
import time
from typing import List, Optional, Tuple
import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
# from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
def sample_requests(
@ -29,22 +37,23 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Shuffle the dataset.
random.shuffle(dataset)
# Filter out too long sequences.
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
@ -53,51 +62,18 @@ def sample_requests(
continue
filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
return filtered_dataset
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
load_in_low_bit: str = "sym_int4",
max_num_batched_tokens: int = 5000,
max_num_seqs: int = 256,
low_bit: str,
engine_args: EngineArgs,
) -> float:
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching,
load_in_low_bit=load_in_low_bit,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,)
llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=low_bit)
# Add the requests to the engine.
warm_prompt = "hi " * (1024 - 1)
@ -111,14 +87,14 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=0.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
llm.generate(prompts, sampling_params, use_tqdm=True)
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
@ -126,29 +102,78 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
start = time.perf_counter()
llm.beam_search(prompts,
beam_width=n,
max_tokens=output_len,
ignore_eos=True)
end = time.perf_counter()
return end - start
async def run_vllm_async(
requests: List[Tuple[str, int, int]],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> float:
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
))
generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start
def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
@ -180,7 +205,7 @@ def run_hf(
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
@ -205,13 +230,15 @@ def run_mii(
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import pipeline
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
llm(prompts, max_new_tokens=output_len)
llm.generate(prompts, max_new_tokens=output_len)
end = time.perf_counter()
client = client(model)
client.terminate_server()
return end - start
@ -224,7 +251,16 @@ def main(args: argparse.Namespace):
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for i in range(-10, 10):
prompt = "hi " * (args.input_len + i)
tokenized_prompt = tokenizer(prompt).input_ids
if len(tokenized_prompt) == args.input_len:
break
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
@ -232,18 +268,21 @@ def main(args: argparse.Namespace):
args.output_len)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.device,
args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit,
args.max_num_batched_tokens,args.max_num_seqs)
if args.async_engine:
elapsed_time = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
))
else:
elapsed_time = run_vllm(requests, args.n, args.load_in_low_bit,
EngineArgs.from_cli_args(args))
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
@ -251,12 +290,26 @@ def main(args: argparse.Namespace):
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
@ -274,89 +327,38 @@ if __name__ == "__main__":
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
'--output-json',
type=str,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "xpu"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument(
"--load-in-low-bit",
type=str,
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
default="sym_int4",
help="Low-bit format quantization with IPEX-LLM")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=4096,
help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=256,
help='Maximum number of sequences per iteration.')
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
@ -379,8 +381,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
@ -388,5 +388,4 @@ if __name__ == "__main__":
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args)
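The run_vllm refactor above replaces the long explicit argument list with vLLM 0.6.2's EngineArgs dataclass plus the IPEX-LLM-specific load_in_low_bit flag. A minimal sketch of that construction path (the model path is a placeholder):

```python
# Build the IPEX-LLM LLM wrapper the same way the refactored run_vllm does:
# all vLLM engine options come from EngineArgs, plus IPEX-LLM's low-bit flag.
import dataclasses

from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

engine_args = EngineArgs(model="YOUR_MODEL_PATH",  # placeholder
                         device="xpu",
                         dtype="float16",
                         max_model_len=2048)
llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit="sym_int4")
outputs = llm.generate(["What is AI?"],
                       SamplingParams(temperature=1.0, top_p=1.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```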

View file

@ -28,4 +28,6 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 2048 \
--max-num-batched-tokens 4000 \
--max-num-seqs 12 \
--tensor-parallel-size 1
--tensor-parallel-size 1 \
--disable-async-output-proc \
--distributed-executor-backend ray

View file

@ -51,6 +51,8 @@ llm = LLM(model="YOUR_MODEL",
enforce_eager=True,
load_in_low_bit="fp8",
tensor_parallel_size=1,
disable_async_output_proc=True,
distributed_executor_backend="ray",
max_model_len=2000,
max_num_batched_tokens=2000)
# Generate texts from the prompts. The output is a list of RequestOutput objects

View file

@ -2,7 +2,7 @@
This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on an Intel GPU (with IPEX-LLM low-bit optimizations).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.3.3).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).
Currently, we support the following models for vLLM engine:
@ -17,7 +17,7 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-
### 0. Environment
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.0. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements).
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).
After installing the toolkit, run the following commands in your environment before starting vLLM on the GPU:
```bash
@ -44,14 +44,12 @@ conda create -n ipex-vllm python=3.11
conda activate ipex-vllm
# Install dependencies
pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install setuptools-scm
pip install --upgrade cmake
# cd to your workdir
git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
cd vllm
pip install -r requirements-xpu.txt
pip install --no-deps xformers
VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e .
pip install outlines==0.0.34 --no-deps
pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
```
@ -60,7 +58,8 @@ pip install transformers_stream_generator einops tiktoken
```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export SYCL_CACHE_PERSISTENT=1
```
### 3. Offline inference/Service
@ -86,6 +85,7 @@ For vLLM, you can start the service using the following command:
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
export VLLM_RPC_TIMEOUT=100000
# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
@ -104,7 +104,8 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 1
--tensor-parallel-size 1 \
--disable-async-output-proc
```
You can tune the service using these four arguments:
@ -200,5 +201,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 2
--tensor-parallel-size 2 \
--distributed-executor-backend ray \
--disable-async-output-proc
```
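Once the api_server is running with the options above, it can be exercised with any OpenAI-compatible client. A minimal sketch, assuming vLLM's default port 8000 and the served_model_name chosen in the launch script:

```python
# Query the OpenAI-compatible endpoint started by ipex_llm.vllm.xpu.entrypoints.openai.api_server.
# The port (8000) and model name are assumptions matching the launch script above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="YOUR_MODEL_NAME",
    messages=[{"role": "user", "content": "What is AI?"}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```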

View file

@ -49,8 +49,10 @@ llm = LLM(model="YOUR_MODEL",
device="xpu",
dtype="float16",
enforce_eager=True,
load_in_low_bit="sym_int4",
tensor_parallel_size=1)
load_in_low_bit="fp8",
tensor_parallel_size=1,
max_model_len=2000,
max_num_batched_tokens=2000)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
@ -58,4 +60,4 @@ outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View file

@ -93,7 +93,7 @@ class VLLMWorker(BaseModelWorker):
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", -1.0)
top_k = params.get("top_k", -1)
presence_penalty = float(params.get("presence_penalty", 0.0))
frequency_penalty = float(params.get("frequency_penalty", 0.0))
max_new_tokens = params.get("max_new_tokens", 256)

View file

@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
"run_mp_engine",
]

View file

@ -19,9 +19,12 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.utils import Counter
from vllm.config import EngineConfig
from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
from vllm.usage.usage_lib import UsageContext
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.multiprocessing.engine import MQLLMEngine
import signal
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@ -32,6 +35,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",
@ -40,7 +44,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers)
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)
class IPEXLLMClass(LLM):
@ -117,3 +123,27 @@ class IPEXLLMLLMEngine(LLMEngine):
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers)
class IPEXLLMMQLLMEngine(MQLLMEngine):
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, ipc_path)
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, load_in_low_bit: str):
def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
signal.signal(signal.SIGTERM, signal_handler)
engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path,
load_in_low_bit=load_in_low_bit)
engine.start()
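run_mp_engine is the new entry point that the OpenAI API server (further down in this diff) launches in a spawned process and then talks to over ZMQ through MQLLMEngineClient. Condensed from that server code, the wiring looks roughly like this (model path and low-bit format are placeholders):

```python
# Spawn the IPEX-LLM MQ engine in its own process and connect a client to it
# over a ZMQ IPC path, mirroring the api_server.py changes in this PR.
import multiprocessing

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_open_zmq_ipc_path
from ipex_llm.vllm.xpu.engine import run_mp_engine

engine_args = AsyncEngineArgs(model="YOUR_MODEL_PATH", device="xpu")
ipc_path = get_open_zmq_ipc_path()

context = multiprocessing.get_context("spawn")
engine_process = context.Process(
    target=run_mp_engine,
    args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path, "sym_int4"))
engine_process.start()

client = MQLLMEngineClient(ipc_path, engine_args.create_engine_config())
# The server then awaits client.setup() before serving requests and terminates
# the engine process on shutdown, as the api_server.py changes below show.
```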

View file

@ -1,25 +1,34 @@
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import signal
import socket
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from multiprocessing import Process
from typing import AsyncIterator, Set
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from ipex_llm.vllm.xpu.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
@ -28,154 +37,269 @@ from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
LoadLoraAdapterRequest,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server
TokenizeResponse,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
seed=0,
dtype="float16").embedding_mode
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
async def _force_log():
while True:
await asyncio.sleep(10.)
await engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
yield
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
yield engine
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
load_in_low_bit=args.load_in_low_bit)
yield async_engine_client
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
load_in_low_bit: str = 'sym_int4',
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_args=engine_args,
load_in_low_bit=load_in_low_bit,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
load_in_low_bit = args.load_in_low_bit
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port, load_in_low_bit))
rpc_server_process.start()
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
# Select random path for IPC.
ipc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path,
load_in_low_bit))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
try:
yield async_engine_client
while True:
try:
await mp_engine_client.setup()
break
except TimeoutError:
if not engine_process.is_alive():
raise RuntimeError(
"Engine process failed to start") from None
yield mp_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
engine_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
mp_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def chat(request: Request) -> OpenAIServingChat:
return request.app.state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion:
return request.app.state.openai_serving_completion
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health")
async def health() -> Response:
async def health(raw_request: Request) -> Response:
"""Health check."""
await async_engine_client.check_health()
await engine_client(raw_request).check_health()
return Response(status_code=200)
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_tokenization.create_tokenize(request)
async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_tokenization.create_detokenize(request)
async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_completion.show_available_models()
async def show_available_models(raw_request: Request):
models = await completion(raw_request).show_available_models()
return JSONResponse(content=models.model_dump())
@ -188,46 +312,110 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
generator = await chat(raw_request).create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
generator = await completion(raw_request).create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
generator = await embedding(raw_request).create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
docs_url=None,
redoc_url=None,
lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
@ -243,7 +431,8 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
chat = app.state.openai_serving_chat
err = chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
@ -275,74 +464,87 @@ def build_app(args: Namespace) -> FastAPI:
return app
async def init_app(
async_engine_client: AsyncEngineClient,
def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
) -> FastAPI:
app = build_app(args)
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
global openai_serving_tokenization
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
openai_serving_chat = OpenAIServingChat(
async_engine_client,
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
served_model_names,
base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
served_model_names,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
served_model_names,
base_model_paths,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
served_model_names,
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path
return app
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
async with build_async_engine_client(args) as async_engine_client:
app = await init_app(async_engine_client, args)
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
temp_socket.close()
shutdown_task = await serve_http(
app,
@ -369,4 +571,4 @@ if __name__ == "__main__":
parser = make_arg_parser(parser)
args = parser.parse_args()
asyncio.run(run_server(args))
uvloop.run(run_server(args))
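The rewritten server routes handlers through app.state instead of module-level globals and keeps the basic introspection endpoints. A small sketch of probing them once the server is up (host and port are assumptions):

```python
# Probe the refactored server's basic endpoints; localhost:8000 is an assumed default.
import requests

base_url = "http://localhost:8000"
print("health:", requests.get(f"{base_url}/health").status_code)  # 200 once the engine is ready
print("version:", requests.get(f"{base_url}/version").json())
print("models:", requests.get(f"{base_url}/v1/models").json())
```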

View file

@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
@ -16,18 +17,55 @@ from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list") # noqa
lora_list: List[LoRAModulePath] = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
if item in [None, '']: # Skip if item is None or empty string
continue
if '=' in item and ',' not in item: # Old format: name=path
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
adapter_list = []
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list") # noqa
adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
@ -72,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.")
help="LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
@ -91,8 +133,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
help="The role name to return if `request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
@ -139,6 +180,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
default=False,
help="Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use")
parser.add_argument(
"--tool-call-parser",
type=str,
choices=["mistral", "hermes"],
default=None,
help="Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--load-in-low-bit",
type=str,
@ -154,6 +212,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument(
"--disable-fastapi-docs",
action='store_true',
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
return parser
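The extended LoRAParserAction accepts both the legacy name=path form and a JSON object per module. A hedged sketch of the two forms, assuming LoRAModulePath's fields are name and path as in the name=path branch above (adapter names and paths are placeholders):

```python
# Both --lora-modules formats accepted after this change: the legacy name=path
# string and a JSON object whose keys map onto LoRAModulePath's fields.
import argparse

from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import LoRAParserAction

parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
args = parser.parse_args([
    "--lora-modules",
    "sql-lora=/path/to/sql_lora",                            # old format
    '{"name": "chat-lora", "path": "/path/to/chat_lora"}',   # JSON format
])
print(args.lora_modules)  # two LoRAModulePath entries
```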

View file

@ -1,221 +0,0 @@
import asyncio
import signal
from typing import Any, Coroutine
import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never
from vllm import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int, load_in_low_bit: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context=usage_context,
load_in_low_bit=load_in_low_bit)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
# Note numeric form of localhost should be used for zmq bind(),
# see https://stackoverflow.com/a/8958414
self.socket.bind(f"tcp://127.0.0.1:{port}")
def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])
async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])
async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
sampling_params=generate_request.sampling_params,
request_id=generate_request.request_id,
lora_request=generate_request.lora_request,
trace_headers=generate_request.trace_headers,
prompt_adapter_request=generate_request.prompt_adapter_request)
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
except Exception as e:
# Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}") # noqa
else:
raise ValueError(f"Unknown RPCRequest type: {request}") # noqa
async def run_server_loop(self):
"""Inner RPC Server Loop"""
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
# Process the request async.
task = asyncio.create_task(
self._make_handler_coro(identity, message))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)
async def run_server(server: AsyncEngineRPCServer):
# Put the server task into the asyncio loop.
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
# Interruption handling.
def signal_handler() -> None:
# Kill the server on interrupt / terminate
server_task.cancel()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("vLLM ZMQ RPC Server was interrupted.")
finally:
# Clean up all resources.
server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int, load_in_low_bit: str):
server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit)
asyncio.run(run_server(server))

View file

@ -75,14 +75,13 @@ def get_load_function(low_bit):
_model_sample_convert()
# from vllm.utils import measure_device_memory
from vllm.utils import CudaMemoryProfiler
with CudaMemoryProfiler() as m:
from vllm.utils import DeviceMemoryProfiler
with DeviceMemoryProfiler() as m:
self.model = get_model(
model_config=self.model_config,
device_config=DeviceConfig("cpu"),
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,