Upgrade to vllm 0.6.2 (#12338)

* Initial updates for vllm 0.6.2

* fix

* Change Dockerfile to support v062

* Fix

* fix examples

* Fix

* done

* fix

* Update engine.py

* Fix Dockerfile to original path

* fix

* add option

* fix

* fix

* fix

* fix

---------

Co-authored-by: xiangyuT <xiangyu.tian@intel.com>
Guancheng Fu, 2024-11-12 20:35:34 +08:00, committed by GitHub
parent 4376fdee62
commit 0ee54fc55f
13 changed files with 617 additions and 520 deletions


@@ -5,6 +5,8 @@ ARG https_proxy
ENV TZ=Asia/Shanghai
ENV PYTHONUNBUFFERED=1
+# To prevent RPC_TIMEOUT ERROR for the first request
+ENV VLLM_RPC_TIMEOUT=100000
# Disable pip's cache behavior
@@ -42,6 +44,14 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
pip install --upgrade colorama && \
# Download all-in-one benchmark and examples
git clone https://github.com/intel-analytics/ipex-llm && \
+# The following comment segment is used when building from source...
+# cd ipex-llm && \
+# git fetch origin pull/12338/head:local_pr && \
+# git checkout local_pr && \
+# pip uninstall -y ipex-llm && \
+# cd python/llm && \
+# python setup.py install && \
+# cd ../../../ && \
cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \
cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \
# Install vllm dependencies
@@ -76,13 +86,16 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
rm -rf /tmp/neo && \
mkdir -p /llm && \
cd /llm && \
-git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
+git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
cd /llm/vllm && \
-pip install -r /llm/vllm/requirements-xpu.txt && \
-VLLM_TARGET_DEVICE=xpu python setup.py install && \
+pip install setuptools-scm && \
+pip install --upgrade cmake && \
+VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
+# pip install -r /llm/vllm/requirements-xpu.txt && \
+# VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install mpi4py fastapi uvicorn openai && \
pip install gradio==4.43.0 && \
-pip install transformers==4.44.2 && \
+# pip install transformers==4.44.2 && \
# patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \
pip install ray && \
patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch
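For reference, an image built from this Dockerfile would typically be started with the Intel GPU devices mapped into the container; the image tag, model path, and mount point below are placeholders and not names fixed by this commit:

```bash
# Hypothetical build/run sketch; tag name and paths are placeholders.
docker build -t ipex-llm-serving-xpu:0.6.2 .

docker run -it --rm \
  --device=/dev/dri \
  -v /path/to/models:/llm/models \
  --shm-size=16g \
  ipex-llm-serving-xpu:0.6.2 /bin/bash
```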


@@ -1,14 +1,22 @@
"""Benchmark offline inference throughput."""
import argparse
+import dataclasses
import json
import random
import time
from typing import List, Optional, Tuple
import torch
+import uvloop
+from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
-from tqdm import tqdm
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.entrypoints.openai.api_server import (
+build_async_engine_client_from_engine_args)
+# from vllm.sampling_params import BeamSearchParams
+from vllm.utils import FlexibleArgumentParser, merge_async_iterators
def sample_requests(
@@ -29,22 +37,23 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
-# Tokenize the prompts and completions.
-prompts = [prompt for prompt, _ in dataset]
-prompt_token_ids = tokenizer(prompts).input_ids
-completions = [completion for _, completion in dataset]
-completion_token_ids = tokenizer(completions).input_ids
-tokenized_dataset = []
-for i in range(len(dataset)):
-output_len = len(completion_token_ids[i])
-if fixed_output_len is not None:
-output_len = fixed_output_len
-tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+# Shuffle the dataset.
+random.shuffle(dataset)
-# Filter out too long sequences.
+# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
-for prompt, prompt_token_ids, output_len in tokenized_dataset:
+for i in range(len(dataset)):
+if len(filtered_dataset) == num_requests:
+break
+# Tokenize the prompts and completions.
+prompt = dataset[i][0]
+prompt_token_ids = tokenizer(prompt).input_ids
+completion = dataset[i][1]
+completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
+output_len = len(completion_token_ids
+) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
@@ -53,51 +62,18 @@ def sample_requests(
continue
filtered_dataset.append((prompt, prompt_len, output_len))
-# Sample the requests.
-sampled_requests = random.sample(filtered_dataset, num_requests)
-return sampled_requests
+return filtered_dataset
def run_vllm(
requests: List[Tuple[str, int, int]],
-model: str,
-tokenizer: str,
-quantization: Optional[str],
-tensor_parallel_size: int,
-seed: int,
n: int,
-use_beam_search: bool,
-trust_remote_code: bool,
-dtype: str,
-max_model_len: Optional[int],
-enforce_eager: bool,
-kv_cache_dtype: str,
-device: str,
-enable_prefix_caching: bool,
-gpu_memory_utilization: float = 0.9,
-load_in_low_bit: str = "sym_int4",
-max_num_batched_tokens: int = 5000,
-max_num_seqs: int = 256,
+low_bit: str,
+engine_args: EngineArgs,
) -> float:
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
-llm = LLM(model=model,
-tokenizer=tokenizer,
-quantization=quantization,
-tensor_parallel_size=tensor_parallel_size,
-seed=seed,
-trust_remote_code=trust_remote_code,
-dtype=dtype,
-max_model_len=max_model_len,
-gpu_memory_utilization=gpu_memory_utilization,
-enforce_eager=enforce_eager,
-kv_cache_dtype=kv_cache_dtype,
-device=device,
-enable_prefix_caching=enable_prefix_caching,
-load_in_low_bit=load_in_low_bit,
-max_num_batched_tokens=max_num_batched_tokens,
-max_num_seqs=max_num_seqs,)
+llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=low_bit)
# Add the requests to the engine.
warm_prompt = "hi " * (1024 - 1)
@@ -111,14 +87,14 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
-temperature=0.0 if use_beam_search else 1.0,
+temperature=0.0,
top_p=1.0,
-use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
llm.generate(prompts, sampling_params, use_tqdm=True)
-# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
@@ -126,29 +102,78 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
-temperature=0.0 if use_beam_search else 1.0,
+temperature=1.0,
top_p=1.0,
-use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
-start = time.perf_counter()
-llm.generate(prompts, sampling_params, use_tqdm=True)
-end = time.perf_counter()
+use_beam_search = False
+if not use_beam_search:
+start = time.perf_counter()
+llm.generate(prompts, sampling_params, use_tqdm=True)
+end = time.perf_counter()
+else:
+prompts = [prompt for prompt, _, _ in requests]
+# output_len should be the same for all requests.
+output_len = requests[0][2]
+for prompt, input_len, _output_len in requests:
+assert _output_len == output_len
+start = time.perf_counter()
+llm.beam_search(prompts,
+beam_width=n,
+max_tokens=output_len,
+ignore_eos=True)
+end = time.perf_counter()
return end - start
async def run_vllm_async(
requests: List[Tuple[str, int, int]],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> float:
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
))
generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start
def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
-use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
-assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
@@ -180,7 +205,7 @@ def run_hf(
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
-do_sample=not use_beam_search,
+do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
@@ -205,13 +230,15 @@ def run_mii(
tensor_parallel_size: int,
output_len: int,
) -> float:
-from mii import pipeline
-llm = pipeline(model, tensor_parallel=tensor_parallel_size)
+from mii import client, serve
+llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
-llm(prompts, max_new_tokens=output_len)
+llm.generate(prompts, max_new_tokens=output_len)
end = time.perf_counter()
+client = client(model)
+client.terminate_server()
return end - start
@@ -224,7 +251,16 @@ def main(args: argparse.Namespace):
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
-prompt = "hi" * (args.input_len - 1)
+# As tokenizer may add additional tokens like BOS, we need to try
+# different lengths to get the desired input length.
+for i in range(-10, 10):
+prompt = "hi " * (args.input_len + i)
+tokenized_prompt = tokenizer(prompt).input_ids
+if len(tokenized_prompt) == args.input_len:
+break
+else:
+raise ValueError(
+f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
@@ -232,18 +268,21 @@ def main(args: argparse.Namespace):
args.output_len)
if args.backend == "vllm":
-elapsed_time = run_vllm(
-requests, args.model, args.tokenizer, args.quantization,
-args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
-args.trust_remote_code, args.dtype, args.max_model_len,
-args.enforce_eager, args.kv_cache_dtype, args.device,
-args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit,
-args.max_num_batched_tokens,args.max_num_seqs)
+if args.async_engine:
+elapsed_time = uvloop.run(
+run_vllm_async(
+requests,
+args.n,
+AsyncEngineArgs.from_cli_args(args),
+args.disable_frontend_multiprocessing,
+))
+else:
+elapsed_time = run_vllm(requests, args.n, args.load_in_low_bit,
+EngineArgs.from_cli_args(args))
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
-args.use_beam_search, args.hf_max_batch_size,
-args.trust_remote_code)
+args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
@@ -251,12 +290,26 @@ def main(args: argparse.Namespace):
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
+total_output_tokens = sum(output_len for _, _, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
-f"{total_num_tokens / elapsed_time:.2f} tokens/s")
+f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
+f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
-parser = argparse.ArgumentParser(description="Benchmark the throughput.")
+parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
@@ -274,89 +327,38 @@ if __name__ == "__main__":
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
-parser.add_argument("--model", type=str, default="facebook/opt-125m")
-parser.add_argument("--tokenizer", type=str, default=None)
-parser.add_argument('--quantization',
-'-q',
-choices=['awq', 'gptq', 'squeezellm', None],
-default=None)
-parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
-parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
-parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
-parser.add_argument('--trust-remote-code',
-action='store_true',
-help='trust remote code from huggingface')
parser.add_argument(
-'--max-model-len',
+'--output-json',
-type=int,
+type=str,
default=None,
-help='Maximum length of a sequence (including prompt and output). '
-'If None, will be derived from the model.')
+help='Path to save the throughput results in JSON format.')
-parser.add_argument(
-'--dtype',
-type=str,
-default='auto',
-choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
-help='data type for model weights and activations. '
-'The "auto" option will use FP16 precision '
-'for FP32 and FP16 models, and BF16 precision '
-'for BF16 models.')
+parser.add_argument("--async-engine",
+action='store_true',
+default=False,
+help="Use vLLM async engine rather than LLM class.")
+parser.add_argument("--disable-frontend-multiprocessing",
+action='store_true',
+default=False,
+help="Disable decoupled async engine frontend.")
-parser.add_argument('--gpu-memory-utilization',
-type=float,
-default=0.9,
-help='the fraction of GPU memory to be used for '
-'the model executor, which can range from 0 to 1.'
-'If unspecified, will use the default value of 0.9.')
-parser.add_argument("--enforce-eager",
-action="store_true",
-help="enforce eager execution")
-parser.add_argument(
-"--kv-cache-dtype",
-type=str,
-choices=["auto", "fp8_e5m2"],
-default="auto",
-help=
-'Data type for kv cache storage. If "auto", will use model data type.')
-parser.add_argument(
-"--device",
-type=str,
-default="cuda",
-choices=["cuda", "xpu"],
-help='device type for vLLM execution, supporting CUDA only currently.')
-parser.add_argument(
-"--enable-prefix-caching",
-action='store_true',
-help="enable automatic prefix caching for vLLM backend.")
parser.add_argument(
"--load-in-low-bit",
type=str,
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
default="sym_int4",
help="Low-bit format quantization with IPEX-LLM")
+parser = AsyncEngineArgs.add_cli_args(parser)
-parser.add_argument('--max-num-batched-tokens',
-type=int,
-default=4096,
-help='maximum number of batched tokens per iteration')
-parser.add_argument('--max-num-seqs',
-type=int,
-default=256,
-help='Maximum number of sequences per iteration.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
@@ -379,8 +381,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
-if args.use_beam_search:
-raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
@@ -388,5 +388,4 @@ if __name__ == "__main__":
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args)
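With these changes, most options now come from the standard vLLM engine arguments (added via `AsyncEngineArgs.add_cli_args`) plus the IPEX-LLM specific flags. For illustration only, an invocation might look like the following; the script filename, dataset file, and model path are placeholders, not names fixed by this commit:

```bash
# Hypothetical example; filename and paths are placeholders.
python benchmark_throughput.py \
  --backend vllm \
  --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
  --model /llm/models/Llama-2-7b-chat-hf \
  --device xpu \
  --load-in-low-bit sym_int4 \
  --num-prompts 100 \
  --async-engine
```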


@@ -28,4 +28,6 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 2048 \
--max-num-batched-tokens 4000 \
--max-num-seqs 12 \
---tensor-parallel-size 1
+--tensor-parallel-size 1 \
+--disable-async-output-proc \
+--distributed-executor-backend ray
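Once the OpenAI-compatible server is up, it can be verified with a standard completion request; the host, port, and model name below are placeholders that must match how the server was actually launched:

```bash
# Hypothetical request; adjust host, port and served model name to your launch command.
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "YOUR_MODEL_NAME",
        "prompt": "San Francisco is a",
        "max_tokens": 64,
        "temperature": 0
      }'
```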


@@ -51,6 +51,8 @@ llm = LLM(model="YOUR_MODEL",
enforce_eager=True,
load_in_low_bit="fp8",
tensor_parallel_size=1,
+disable_async_output_proc=True,
+distributed_executor_backend="ray",
max_model_len=2000,
max_num_batched_tokens=2000)
# Generate texts from the prompts. The output is a list of RequestOutput objects


@@ -2,7 +2,7 @@
This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bits optimizations).
-The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.3.3).
+The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).
Currently, we support the following models for vLLM engine:
@@ -17,7 +17,7 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-
### 0. Environment
-To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.0. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements).
+To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).
After installing the toolkit, run the following commands in your environment before starting vLLM GPU:
```bash
@@ -44,14 +44,12 @@ conda create -n ipex-vllm python=3.11
conda activate ipex-vllm
# Install dependencies
pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install setuptools-scm
+pip install --upgrade cmake
# cd to your workdir
-git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git
+git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
cd vllm
-pip install -r requirements-xpu.txt
-pip install --no-deps xformers
-VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e .
-pip install outlines==0.0.34 --no-deps
-pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy
+VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
```
@@ -60,7 +58,8 @@ pip install transformers_stream_generator einops tiktoken
```bash
export USE_XETLA=OFF
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export SYCL_CACHE_PERSISTENT=1
```
### 3. Offline inference/Service
@@ -86,6 +85,7 @@ For vLLM, you can start the service using the following command:
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
+export VLLM_RPC_TIMEOUT=100000
# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
@@ -104,7 +104,8 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
---tensor-parallel-size 1
+--tensor-parallel-size 1 \
+--disable-async-output-proc
```
You can tune the service using these four arguments:
@@ -200,5 +201,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
---tensor-parallel-size 2
+--tensor-parallel-size 2 \
+--distributed-executor-backend ray \
+--disable-async-output-proc
```


@@ -49,8 +49,10 @@ llm = LLM(model="YOUR_MODEL",
device="xpu",
dtype="float16",
enforce_eager=True,
-load_in_low_bit="sym_int4",
+load_in_low_bit="fp8",
-tensor_parallel_size=1)
+tensor_parallel_size=1,
+max_model_len=2000,
+max_num_batched_tokens=2000)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
@@ -58,4 +60,4 @@ outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@@ -93,7 +93,7 @@ class VLLMWorker(BaseModelWorker):
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
-top_k = params.get("top_k", -1.0)
+top_k = params.get("top_k", -1)
presence_penalty = float(params.get("presence_penalty", 0.0))
frequency_penalty = float(params.get("frequency_penalty", 0.0))
max_new_tokens = params.get("max_new_tokens", 256)


@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
+from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
+"run_mp_engine",
]


@@ -19,9 +19,12 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.utils import Counter
+from vllm.config import EngineConfig
from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
from vllm.usage.usage_lib import UsageContext
from vllm.engine.metrics import StatLoggerBase
+from vllm.engine.multiprocessing.engine import MQLLMEngine
+import signal
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@@ -32,6 +35,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
+engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",
@@ -40,7 +44,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
-return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers)
+return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
+start_engine_loop=start_engine_loop,
+usage_context=usage_context, stat_loggers=stat_loggers)
class IPEXLLMClass(LLM):
@@ -117,3 +123,27 @@ class IPEXLLMLLMEngine(LLMEngine):
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers)
class IPEXLLMMQLLMEngine(MQLLMEngine):
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, ipc_path)
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, load_in_low_bit: str):
def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
signal.signal(signal.SIGTERM, signal_handler)
engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path,
load_in_low_bit=load_in_low_bit)
engine.start()


@@ -1,25 +1,34 @@
import asyncio
import importlib
import inspect
+import multiprocessing
+import os
import re
+import signal
+import socket
+import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
+from functools import partial
from http import HTTPStatus
-from multiprocessing import Process
from typing import AsyncIterator, Set
+import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
-from prometheus_client import make_asgi_app
+from starlette.datastructures import State
from starlette.routing import Mount
+from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
-from vllm.engine.protocol import AsyncEngineClient
+from vllm.engine.multiprocessing.client import MQLLMEngineClient
+from ipex_llm.vllm.xpu.engine import run_mp_engine
+from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
@@ -28,154 +37,269 @@ from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
+CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
-EmbeddingRequest, ErrorResponse,
+EmbeddingRequest,
+EmbeddingResponse, ErrorResponse,
+LoadLoraAdapterRequest,
TokenizeRequest,
-TokenizeResponse)
+TokenizeResponse,
+UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
-from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser, get_open_port
+from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5  # seconds
-async_engine_client: AsyncEngineClient
-engine_args: AsyncEngineArgs
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-openai_serving_tokenization: OpenAIServingTokenization
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
seed=0,
dtype="float16").embedding_mode
@asynccontextmanager
async def lifespan(app: FastAPI):
+try:
+if app.state.log_stats:
+engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
-await asyncio.sleep(10)
+await asyncio.sleep(10.)
-await async_engine_client.do_log_stats()
+await engine_client.do_log_stats()
-if not engine_args.disable_log_stats:
-task = asyncio.create_task(_force_log())
-_running_tasks.add(task)
-task.add_done_callback(_running_tasks.remove)
-yield
+task = asyncio.create_task(_force_log())
+_running_tasks.add(task)
+task.add_done_callback(_running_tasks.remove)
+else:
+task = None
+try:
+yield
+finally:
+if task is not None:
+task.cancel()
+finally:
+# Ensure app state including engine ref is gc'd
+del app.state
@asynccontextmanager
-async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
-# Context manager to handle async_engine_client lifecycle
+async def build_async_engine_client(
+args: Namespace) -> AsyncIterator[EngineClient]:
+# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
-global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
-# Backend itself still global for the silly lil' health handler
-global async_engine_client
-# If manually triggered or embedding model, use AsyncLLMEngine in process.
-# TODO: support embedding model via RPC.
-if (model_is_embedding(args.model, args.trust_remote_code)
-or args.disable_frontend_multiprocessing):
-async_engine_client = AsyncLLMEngine.from_engine_args(
-engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
-load_in_low_bit=args.load_in_low_bit)
-yield async_engine_client
+async with build_async_engine_client_from_engine_args(
+engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
+yield engine
+@asynccontextmanager
+async def build_async_engine_client_from_engine_args(
+engine_args: AsyncEngineArgs,
+disable_frontend_multiprocessing: bool = False,
+load_in_low_bit: str = 'sym_int4',
+) -> AsyncIterator[EngineClient]:
+"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_args=engine_args,
load_in_low_bit=load_in_low_bit,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
-# Start RPCServer in separate process (holds the AsyncLLMEngine).
-port = get_open_port(envs.VLLM_RPC_PORT)
-load_in_low_bit = args.load_in_low_bit
-rpc_server_process = Process(target=run_rpc_server,
-args=(engine_args,
-UsageContext.OPENAI_API_SERVER,
-port, load_in_low_bit))
-rpc_server_process.start()
+if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
+# Make TemporaryDirectory for prometheus multiprocessing
+# Note: global TemporaryDirectory will be automatically
+# cleaned up upon exit.
+global prometheus_multiproc_dir
+prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+os.environ[
+"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+else:
+logger.warning(
+"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
+"This directory must be wiped between vLLM runs or "
+"you will find inaccurate metrics. Unset the variable "
+"and vLLM will properly handle cleanup.")
-# Build RPCClient, which conforms to AsyncEngineClient Protocol.
-async_engine_client = AsyncEngineRPCClient(port)
-await async_engine_client.setup()
+# Select random path for IPC.
+ipc_path = get_open_zmq_ipc_path()
+logger.info("Multiprocessing frontend to use %s for IPC Path.",
+ipc_path)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path,
load_in_low_bit))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
try:
-yield async_engine_client
+while True:
+try:
+await mp_engine_client.setup()
+break
+except TimeoutError:
+if not engine_process.is_alive():
+raise RuntimeError(
+"Engine process failed to start") from None
+yield mp_engine_client  # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
-rpc_server_process.terminate()
+engine_process.terminate()
# Close all open connections to the backend
-async_engine_client.close()
+mp_engine_client.close()
-# Wait for server process to join
+# Wait for engine process to join
-rpc_server_process.join()
+engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
-# Add prometheus asgi middleware to route /metrics requests
-metrics_route = Mount("/metrics", make_asgi_app())
+# Lazy import for prometheus multiprocessing.
+# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
-metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def chat(request: Request) -> OpenAIServingChat:
return request.app.state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion:
return request.app.state.openai_serving_completion
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health") @router.get("/health")
async def health() -> Response: async def health(raw_request: Request) -> Response:
"""Health check.""" """Health check."""
await async_engine_client.check_health() await engine_client(raw_request).check_health()
return Response(status_code=200) return Response(status_code=200)
@router.post("/tokenize") @router.post("/tokenize")
async def tokenize(request: TokenizeRequest): async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await openai_serving_tokenization.create_tokenize(request) generator = await tokenization(raw_request).create_tokenize(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
else: elif isinstance(generator, TokenizeResponse):
assert isinstance(generator, TokenizeResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize") @router.post("/detokenize")
async def detokenize(request: DetokenizeRequest): async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await openai_serving_tokenization.create_detokenize(request) generator = await tokenization(raw_request).create_detokenize(request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.code)
else: elif isinstance(generator, DetokenizeResponse):
assert isinstance(generator, DetokenizeResponse)
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models") @router.get("/v1/models")
async def show_available_models(): async def show_available_models(raw_request: Request):
models = await openai_serving_completion.show_available_models() models = await completion(raw_request).show_available_models()
return JSONResponse(content=models.model_dump()) return JSONResponse(content=models.model_dump())
@@ -188,46 +312,110 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
-generator = await openai_serving_chat.create_chat_completion(
+generator = await chat(raw_request).create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
-if request.stream:
-return StreamingResponse(content=generator,
-media_type="text/event-stream")
-else:
-assert isinstance(generator, ChatCompletionResponse)
+elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
+return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
-generator = await openai_serving_completion.create_completion(
+generator = await completion(raw_request).create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
-if request.stream:
-return StreamingResponse(content=generator,
-media_type="text/event-stream")
-else:
+elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
+return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-generator = await openai_serving_embedding.create_embedding(
+generator = await embedding(raw_request).create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
-else:
+elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
+assert_never(generator)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
def build_app(args: Namespace) -> FastAPI:
-app = FastAPI(lifespan=lifespan)
+if args.disable_fastapi_docs:
+app = FastAPI(openapi_url=None,
+docs_url=None,
+redoc_url=None,
+lifespan=lifespan)
+else:
+app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
@@ -243,7 +431,8 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
-err = openai_serving_chat.create_error_response(message=str(exc))
+chat = app.state.openai_serving_chat
+err = chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
@@ -275,74 +464,87 @@ def build_app(args: Namespace) -> FastAPI:
return app
-async def init_app(
-async_engine_client: AsyncEngineClient,
+def init_app_state(
+engine_client: EngineClient,
+model_config: ModelConfig,
+state: State,
args: Namespace,
-) -> FastAPI:
-app = build_app(args)
+) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
-model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
-global openai_serving_chat
-global openai_serving_completion
-global openai_serving_embedding
-global openai_serving_tokenization
-openai_serving_chat = OpenAIServingChat(
-async_engine_client,
+base_model_paths = [
+BaseModelPath(name=name, model_path=args.model)
+for name in served_model_names
+]
+state.engine_client = engine_client
+state.log_stats = not args.disable_log_stats
+state.openai_serving_chat = OpenAIServingChat(
+engine_client,
model_config,
-served_model_names,
+base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-)
-openai_serving_completion = OpenAIServingCompletion(
-async_engine_client,
+enable_auto_tools=args.enable_auto_tool_choice,
+tool_parser=args.tool_call_parser)
+state.openai_serving_completion = OpenAIServingCompletion(
+engine_client,
model_config,
-served_model_names,
+base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
-openai_serving_embedding = OpenAIServingEmbedding(
-async_engine_client,
+state.openai_serving_embedding = OpenAIServingEmbedding(
+engine_client,
model_config,
-served_model_names,
+base_model_paths,
request_logger=request_logger,
)
-openai_serving_tokenization = OpenAIServingTokenization(
-async_engine_client,
+state.openai_serving_tokenization = OpenAIServingTokenization(
+engine_client,
model_config,
-served_model_names,
+base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
-app.root_path = args.root_path
-return app
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
-async with build_async_engine_client(args) as async_engine_client:
-app = await init_app(async_engine_client, args)
+temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+temp_socket.bind(("", args.port))
+def signal_handler(*_) -> None:
+# Interrupt server on sigterm while initializing
+raise KeyboardInterrupt("terminated")
+signal.signal(signal.SIGTERM, signal_handler)
+async with build_async_engine_client(args) as engine_client:
+app = build_app(args)
+model_config = await engine_client.get_model_config()
+init_app_state(engine_client, model_config, app.state, args)
+temp_socket.close()
shutdown_task = await serve_http(
app,
@@ -369,4 +571,4 @@ if __name__ == "__main__":
parser = make_arg_parser(parser)
args = parser.parse_args()
-asyncio.run(run_server(args))
+uvloop.run(run_server(args))


@@ -7,6 +7,7 @@ purposes.
import argparse
import json
import ssl
+from typing import List, Optional, Sequence, Union
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
@@ -16,18 +17,55 @@ from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
-def __call__(self, parser, namespace, values, option_string=None):
-lora_list = []
+def __call__(
+self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list") # noqa
lora_list: List[LoRAModulePath] = []
for item in values: for item in values:
name, path = item.split('=') if item in [None, '']: # Skip if item is None or empty string
lora_list.append(LoRAModulePath(name, path)) continue
if '=' in item and ',' not in item: # Old format: name=path
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list) setattr(namespace, self.dest, lora_list)
class PromptAdapterParserAction(argparse.Action): class PromptAdapterParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None): def __call__(
adapter_list = [] self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list") # noqa
adapter_list: List[PromptAdapterPath] = []
for item in values: for item in values:
name, path = item.split('=') name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path)) adapter_list.append(PromptAdapterPath(name, path))
@ -72,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None, default=None,
nargs='+', nargs='+',
action=LoRAParserAction, action=LoRAParserAction,
help="LoRA module configurations in the format name=path. " help="LoRA module configurations in either 'name=path' format"
"Multiple modules can be specified.") "or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
@ -91,8 +133,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
-help="The role name to return if "
-"`request.add_generation_prompt=true`.")
+help="The role name to return if `request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
@ -139,6 +180,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
+parser.add_argument(
+"--enable-auto-tool-choice",
+action="store_true",
+default=False,
+help="Enable auto tool choice for supported models. Use --tool-call-parser"
+"to specify which parser to use")
+parser.add_argument(
+"--tool-call-parser",
+type=str,
+choices=["mistral", "hermes"],
+default=None,
+help="Select the tool call parser depending on the model that you're using."
+" This is used to parse the model-generated tool call into OpenAI API "
+"format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--load-in-low-bit",
type=str,
@ -154,6 +212,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
+parser.add_argument(
+"--disable-fastapi-docs",
+action='store_true',
+default=False,
+help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
+)
return parser

View file

@ -1,221 +0,0 @@
import asyncio
import signal
from typing import Any, Coroutine
import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never
from vllm import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
class AsyncEngineRPCServer:
def __init__(self, async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int, load_in_low_bit: str):
# Initialize engine first.
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
usage_context=usage_context,
load_in_low_bit=load_in_low_bit)
# Initialize context.
self.context = zmq.asyncio.Context()
# Init socket for readiness state.
self.socket = self.context.socket(zmq.constants.ROUTER)
# Note numeric form of localhost should be used for zmq bind(),
# see https://stackoverflow.com/a/8958414
self.socket.bind(f"tcp://127.0.0.1:{port}")
def cleanup(self):
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
async def get_model_config(self, identity):
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(model_config)])
async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(decoding_config)])
async def get_lora_config(self, identity):
lora_config = await self.engine.get_lora_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(lora_config)])
async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
parallel_config = await self.engine.get_scheduler_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(parallel_config)])
async def is_tracing_enabled(self, identity):
"""Send the is_tracing_enabled flag"""
tracing_flag = await self.engine.is_tracing_enabled()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(tracing_flag)])
async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
results_generator = self.engine.generate(
generate_request.inputs,
sampling_params=generate_request.sampling_params,
request_id=generate_request.request_id,
lora_request=generate_request.lora_request,
trace_headers=generate_request.trace_headers,
prompt_adapter_request=generate_request.prompt_adapter_request)
async for request_output in results_generator:
await self.socket.send_multipart(
[identity, cloudpickle.dumps(request_output)])
except Exception as e:
# Notify client of all failures
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
async def check_health(self, identity):
try:
await self.engine.check_health()
await self.socket.send_multipart(
[identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
except Exception as e:
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
"""Route the zmq message to the handler coroutine."""
request = cloudpickle.loads(message)
if isinstance(request, RPCGenerateRequest):
return self.generate(identity, request)
elif isinstance(request, RPCAbortRequest):
return self.abort(identity, request)
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
return self.is_server_ready(identity)
elif request == RPCUtilityRequest.CHECK_HEALTH:
return self.check_health(identity)
elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
return self.is_tracing_enabled(identity)
else:
raise ValueError(f"Unknown RPCUtilityRequest type: {request}") # noqa
else:
raise ValueError(f"Unknown RPCRequest type: {request}") # noqa
async def run_server_loop(self):
"""Inner RPC Server Loop"""
running_tasks = set()
while True:
# Wait for a request.
identity, message = await self.socket.recv_multipart()
# Process the request async.
task = asyncio.create_task(
self._make_handler_coro(identity, message))
# We need to keep around a strong reference to the task,
# to avoid the task disappearing mid-execution as running tasks
# can be GC'ed. Below is a common "fire-and-forget" tasks
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
running_tasks.add(task)
task.add_done_callback(running_tasks.discard)
async def run_server(server: AsyncEngineRPCServer):
# Put the server task into the asyncio loop.
loop = asyncio.get_running_loop()
server_task = loop.create_task(server.run_server_loop())
# Interruption handling.
def signal_handler() -> None:
# Kill the server on interrupt / terminate
server_task.cancel()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("vLLM ZMQ RPC Server was interrupted.")
finally:
# Clean up all resources.
server.cleanup()
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, port: int, load_in_low_bit: str):
server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit)
asyncio.run(run_server(server))
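The run_server_loop comment in the file removed above ("keep around a strong reference to the task") refers to a general asyncio pitfall: the event loop holds only weak references to tasks created with asyncio.create_task. A small standalone illustration, unrelated to the vLLM API; the message names and counts are made up:

import asyncio


async def handle(msg: str) -> None:
    await asyncio.sleep(0.1)
    print("handled", msg)


async def loop_forever() -> None:
    running_tasks = set()
    for i in range(3):  # stands in for "while True: ... recv_multipart()"
        task = asyncio.create_task(handle(f"msg-{i}"))
        # asyncio keeps only weak references to tasks; without this set a
        # pending task could be garbage-collected before it runs to completion.
        running_tasks.add(task)
        task.add_done_callback(running_tasks.discard)
    # Give the fire-and-forget tasks time to finish before the loop exits.
    await asyncio.sleep(0.5)


asyncio.run(loop_forever())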

View file

@ -75,14 +75,13 @@ def get_load_function(low_bit):
_model_sample_convert()
# from vllm.utils import measure_device_memory
-from vllm.utils import CudaMemoryProfiler
-with CudaMemoryProfiler() as m:
+from vllm.utils import DeviceMemoryProfiler
+with DeviceMemoryProfiler() as m:
self.model = get_model(
model_config=self.model_config,
device_config=DeviceConfig("cpu"),
load_config=self.load_config,
lora_config=self.lora_config,
-multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,