Upgrade to vllm 0.6.2 (#12338)
* Initial updates for vllm 0.6.2
* fix
* Change Dockerfile to support v062
* Fix
* fix examples
* Fix
* done
* fix
* Update engine.py
* Fix Dockerfile to original path
* fix
* add option
* fix
* fix
* fix
* fix
---------
Co-authored-by: xiangyuT <xiangyu.tian@intel.com>
This commit is contained in:
parent 4376fdee62
commit 0ee54fc55f
13 changed files with 617 additions and 520 deletions
@@ -5,6 +5,8 @@ ARG https_proxy
ENV TZ=Asia/Shanghai
ENV PYTHONUNBUFFERED=1
# To prevent RPC_TIMEOUT ERROR for the first request
ENV VLLM_RPC_TIMEOUT=100000

# Disable pip's cache behavior

@@ -42,6 +44,14 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
pip install --upgrade colorama && \
# Download all-in-one benchmark and examples
git clone https://github.com/intel-analytics/ipex-llm && \
# The following comment segment is used when building from source...
# cd ipex-llm && \
# git fetch origin pull/12338/head:local_pr && \
# git checkout local_pr && \
# pip uninstall -y ipex-llm && \
# cd python/llm && \
# python setup.py install && \
# cd ../../../ && \
cp -r ./ipex-llm/python/llm/dev/benchmark/ ./benchmark && \
cp -r ./ipex-llm/python/llm/example/GPU/HuggingFace/LLM ./examples && \
# Install vllm dependencies

@@ -76,13 +86,16 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
rm -rf /tmp/neo && \
mkdir -p /llm && \
cd /llm && \
git clone -b 0.5.4 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git /llm/vllm && \
cd /llm/vllm && \
pip install -r /llm/vllm/requirements-xpu.txt && \
VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install setuptools-scm && \
pip install --upgrade cmake && \
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
# pip install -r /llm/vllm/requirements-xpu.txt && \
# VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install mpi4py fastapi uvicorn openai && \
pip install gradio==4.43.0 && \
pip install transformers==4.44.2 && \
# pip install transformers==4.44.2 && \
# patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch && \
pip install ray && \
patch /usr/local/lib/python3.11/dist-packages/fastchat/serve/gradio_web_server.py < /tmp/gradio_web_server.patch
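The Dockerfile changes above pin vLLM to the analytics-zoo 0.6.2 branch and bake in VLLM_RPC_TIMEOUT. As a hedged illustration only, this is roughly how such an image is typically built and started; the image tag, container name, and model mount are placeholders, not values taken from this PR:

```bash
# Hypothetical tag, container name, and paths -- adjust to your environment.
docker build -t ipex-llm-serving-xpu:vllm-0.6.2 .

# /dev/dri exposes the Intel GPU to the container; the model directory is mounted read-only.
docker run -itd --name vllm-062 \
    --device=/dev/dri \
    -v /path/to/models:/llm/models:ro \
    ipex-llm-serving-xpu:vllm-0.6.2
```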
@@ -1,14 +1,22 @@
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import random
import time
from typing import List, Optional, Tuple

import torch
import uvloop
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
# from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators

def sample_requests(

@@ -29,22 +37,23 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Shuffle the dataset.
random.shuffle(dataset)

# Filter out too long sequences.
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break

# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue

@@ -53,51 +62,18 @@ def sample_requests(
continue
filtered_dataset.append((prompt, prompt_len, output_len))

# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
return filtered_dataset

def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
load_in_low_bit: str = "sym_int4",
max_num_batched_tokens: int = 5000,
max_num_seqs: int = 256,
low_bit: str,
engine_args: EngineArgs,
) -> float:
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching,
load_in_low_bit=load_in_low_bit,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,)

llm = LLM(**dataclasses.asdict(engine_args), load_in_low_bit=low_bit)

# Add the requests to the engine.
warm_prompt = "hi " * (1024 - 1)

@@ -111,14 +87,14 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=0.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
llm.generate(prompts, sampling_params, use_tqdm=True)

# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:

@@ -126,29 +102,78 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))

start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
use_beam_search = False

if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
start = time.perf_counter()
llm.beam_search(prompts,
beam_width=n,
max_tokens=output_len,
ignore_eos=True)
end = time.perf_counter()
return end - start

async def run_vllm_async(
requests: List[Tuple[str, int, int]],
n: int,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> float:
from vllm import SamplingParams

async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:

# Add the requests to the engine.
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=output_len,
))

generators = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
end = time.perf_counter()
return end - start

def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":

@@ -180,7 +205,7 @@ def run_hf(
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,

@@ -205,13 +230,15 @@ def run_mii(
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import pipeline
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
from mii import client, serve
llm = serve(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]

start = time.perf_counter()
llm(prompts, max_new_tokens=output_len)
llm.generate(prompts, max_new_tokens=output_len)
end = time.perf_counter()
client = client(model)
client.terminate_server()
return end - start

@@ -224,7 +251,16 @@ def main(args: argparse.Namespace):
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for i in range(-10, 10):
prompt = "hi " * (args.input_len + i)
tokenized_prompt = tokenizer(prompt).input_ids
if len(tokenized_prompt) == args.input_len:
break
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:

@@ -232,18 +268,21 @@ def main(args: argparse.Namespace):
args.output_len)

if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.device,
args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit,
args.max_num_batched_tokens,args.max_num_seqs)
if args.async_engine:
elapsed_time = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
))
else:
elapsed_time = run_vllm(requests, args.n, args.load_in_low_bit,
EngineArgs.from_cli_args(args))
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)

@@ -251,12 +290,26 @@ def main(args: argparse.Namespace):
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
total_output_tokens = sum(output_len for _, _, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],

@@ -274,89 +327,38 @@ if __name__ == "__main__":
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
'--output-json',
type=str,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "xpu"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument(
"--load-in-low-bit",
type=str,
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
default="sym_int4",
help="Low-bit format quantization with IPEX-LLM")

parser.add_argument('--max-num-batched-tokens',
type=int,
default=4096,
help='maximum number of batched tokens per iteration')

parser.add_argument('--max-num-seqs',
type=int,
default=256,
help='Maximum number of sequences per iteration.')

parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

@@ -379,8 +381,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:

@@ -388,5 +388,4 @@ if __name__ == "__main__":
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args)
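For reference, a rough sketch of how the reworked throughput benchmark is typically invoked after this change. The script name and model path are assumptions (the file name is not shown in this diff); the flags correspond to arguments visible above (--load-in-low-bit, --max-num-batched-tokens, --max-num-seqs, plus the standard engine arguments pulled in via AsyncEngineArgs.add_cli_args):

```bash
# Hypothetical invocation; script name and model path are placeholders.
python benchmark_throughput.py \
    --backend vllm \
    --device xpu \
    --model YOUR_MODEL_PATH \
    --input-len 1024 \
    --output-len 512 \
    --num-prompts 32 \
    --load-in-low-bit sym_int4 \
    --max-num-batched-tokens 4096 \
    --max-num-seqs 256
```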
@@ -28,4 +28,6 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 2048 \
--max-num-batched-tokens 4000 \
--max-num-seqs 12 \
--tensor-parallel-size 1
--tensor-parallel-size 1 \
--disable-async-output-proc \
--distributed-executor-backend ray
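Once a service started with the command above is up, a quick sanity check is a plain OpenAI-style completion request against the /v1/completions route defined in this PR's api_server. This is a sketch only: port 8000 is vLLM's usual default and the served model name is a placeholder.

```bash
# Assumes the default port 8000 and a model served as YOUR_MODEL_NAME.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "YOUR_MODEL_NAME", "prompt": "San Francisco is a", "max_tokens": 64, "temperature": 0}'
```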
@@ -51,6 +51,8 @@ llm = LLM(model="YOUR_MODEL",
enforce_eager=True,
load_in_low_bit="fp8",
tensor_parallel_size=1,
disable_async_output_proc=True,
distributed_executor_backend="ray",
max_model_len=2000,
max_num_batched_tokens=2000)
# Generate texts from the prompts. The output is a list of RequestOutput objects
@@ -2,7 +2,7 @@

This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with IPEX-LLM low-bits optimizations).

The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.3.3).
The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.6.2).

Currently, we support the following models for vLLM engine:

@@ -17,7 +17,7 @@ In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-

### 0. Environment

To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.0. Please check the requirements at [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU#requirements).
To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit 2024.1. Please check the requirements at [here](https://www.intel.com/content/www/us/en/docs/oneapi/installation-guide-linux/2024-1/overview.html).

After install the toolkit, run the following commands in your environment before starting vLLM GPU:
```bash

@@ -44,14 +44,12 @@ conda create -n ipex-vllm python=3.11
conda activate ipex-vllm
# Install dependencies
pip install --pre --upgrade "ipex-llm[xpu]" --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install setuptools-scm
pip install --upgrade cmake
# cd to your workdir
git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git
git clone -b 0.6.2 https://github.com/analytics-zoo/vllm.git
cd vllm
pip install -r requirements-xpu.txt
pip install --no-deps xformers
VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e .
pip install outlines==0.0.34 --no-deps
pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v .
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
```

@@ -60,7 +58,8 @@ pip install transformers_stream_generator einops tiktoken

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export SYCL_CACHE_PERSISTENT=1
```
### 3. Offline inference/Service

@@ -86,6 +85,7 @@ For vLLM, you can start the service using the following command:
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
export VLLM_RPC_TIMEOUT=100000

# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs

@@ -104,7 +104,8 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 1
--tensor-parallel-size 1 \
--disable-async-output-proc
```

You can tune the service using these four arguments:

@@ -200,5 +201,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 2
--tensor-parallel-size 2 \
--distributed-executor-backend ray \
--disable-async-output-proc
```
@@ -49,8 +49,10 @@ llm = LLM(model="YOUR_MODEL",
device="xpu",
dtype="float16",
enforce_eager=True,
load_in_low_bit="sym_int4",
tensor_parallel_size=1)
load_in_low_bit="fp8",
tensor_parallel_size=1,
max_model_len=2000,
max_num_batched_tokens=2000)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

@@ -58,4 +60,4 @@ outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
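The offline example above starts mid-file, so the `prompts` list and `sampling_params` it uses are not shown in the hunk. A minimal sketch of the assumed preamble, using the same imports that appear elsewhere in this PR (`IPEXLLMClass` re-exported as `LLM`, vLLM's `SamplingParams`); the prompt strings are placeholders:

```python
# Assumed preamble for the example above; prompt texts are placeholders.
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

prompts = [
    "Hello, my name is",
    "The capital of France is",
]
# max_tokens caps the generated length per prompt.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
```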
@@ -93,7 +93,7 @@ class VLLMWorker(BaseModelWorker):
request_id = params.pop("request_id")
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", -1.0)
top_k = params.get("top_k", -1)
presence_penalty = float(params.get("presence_penalty", 0.0))
frequency_penalty = float(params.get("frequency_penalty", 0.0))
max_new_tokens = params.get("max_new_tokens", 256)
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
"run_mp_engine",
]
@@ -19,9 +19,12 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.utils import Counter
from vllm.config import EngineConfig
from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
from vllm.usage.usage_lib import UsageContext
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.multiprocessing.engine import MQLLMEngine
import signal

class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):

@@ -32,6 +35,7 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",

@@ -40,7 +44,9 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, start_engine_loop, usage_context, stat_loggers)
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)

class IPEXLLMClass(LLM):

@@ -117,3 +123,27 @@ class IPEXLLMLLMEngine(LLMEngine):
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers)

class IPEXLLMMQLLMEngine(MQLLMEngine):
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, ipc_path)

def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, load_in_low_bit: str):

def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa

signal.signal(signal.SIGTERM, signal_handler)

engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path,
load_in_low_bit=load_in_low_bit)
engine.start()
@@ -1,25 +1,34 @@
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import signal
import socket
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from multiprocessing import Process
from typing import AsyncIterator, Set

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from ipex_llm.vllm.xpu.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser

@@ -28,154 +37,269 @@ from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest, ErrorResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
LoadLoraAdapterRequest,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from ipex_llm.vllm.xpu.entrypoints.openai.rpc.server import run_rpc_server
TokenizeResponse,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION

TIMEOUT_KEEP_ALIVE = 5 # seconds

async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

_running_tasks: Set[asyncio.Task] = set()

def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
seed=0,
dtype="float16").embedding_mode

@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
engine_client: EngineClient = app.state.engine_client

async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
async def _force_log():
while True:
await asyncio.sleep(10.)
await engine_client.do_log_stats()

if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)

yield
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state

@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:

# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)

# Backend itself still global for the silly lil' health handler
global async_engine_client
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
yield engine

# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model, args.trust_remote_code)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
load_in_low_bit=args.load_in_low_bit)
yield async_engine_client

@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
load_in_low_bit: str = 'sym_int4',
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC

Returns the Client or None if the creation failed.
"""

# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)

build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_args=engine_args,
load_in_low_bit=load_in_low_bit,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)

yield engine_client
return

# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
load_in_low_bit = args.load_in_low_bit
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port, load_in_low_bit))
rpc_server_process.start()
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")

# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
# Select random path for IPC.
ipc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)

# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")

engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path,
load_in_low_bit))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)

# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)

try:
yield async_engine_client
while True:
try:
await mp_engine_client.setup()
break
except TimeoutError:
if not engine_process.is_alive():
raise RuntimeError(
"Engine process failed to start") from None

yield mp_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
engine_process.terminate()

# Close all open connections to the backend
async_engine_client.close()
mp_engine_client.close()

# Wait for server process to join
rpc_server_process.join()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()

# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)

router = APIRouter()

def mount_metrics(app: FastAPI):
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)

prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)

# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())

# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)

def chat(request: Request) -> OpenAIServingChat:
return request.app.state.openai_serving_chat

def completion(request: Request) -> OpenAIServingCompletion:
return request.app.state.openai_serving_completion

def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization

def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding

def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client

@router.get("/health")
async def health() -> Response:
async def health(raw_request: Request) -> Response:
"""Health check."""
await async_engine_client.check_health()
await engine_client(raw_request).check_health()
return Response(status_code=200)

@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_tokenization.create_tokenize(request)
async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)

@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_tokenization.create_detokenize(request)
async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)

@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_completion.show_available_models()
async def show_available_models(raw_request: Request):
models = await completion(raw_request).show_available_models()
return JSONResponse(content=models.model_dump())

@@ -188,46 +312,110 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(

generator = await chat(raw_request).create_chat_completion(
request, raw_request)

if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)

elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())

return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
generator = await completion(raw_request).create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())

return StreamingResponse(content=generator, media_type="text/event-stream")

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
generator = await embedding(raw_request).create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())

assert_never(generator)

if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")

@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)

@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)

if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")

@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200, content=response)

@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200, content=response)

def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
docs_url=None,
redoc_url=None,
lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path

@@ -243,7 +431,8 @@ def build_app(args: Namespace) -> FastAPI:

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
chat = app.state.openai_serving_chat
err = chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)

@@ -275,74 +464,87 @@ def build_app(args: Namespace) -> FastAPI:
return app

async def init_app(
async_engine_client: AsyncEngineClient,
def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
) -> FastAPI:
app = build_app(args)

) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]

model_config = await async_engine_client.get_model_config()

if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)

global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
global openai_serving_tokenization
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]

openai_serving_chat = OpenAIServingChat(
async_engine_client,
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats

state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
served_model_names,
base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
served_model_names,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
served_model_names,
base_model_paths,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
served_model_names,
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path

return app

async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)

async with build_async_engine_client(args) as async_engine_client:
app = await init_app(async_engine_client, args)
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
temp_socket.bind(("", args.port))

def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")

signal.signal(signal.SIGTERM, signal_handler)

async with build_async_engine_client(args) as engine_client:
app = build_app(args)

model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)

temp_socket.close()

shutdown_task = await serve_http(
app,

@@ -369,4 +571,4 @@ if __name__ == "__main__":
parser = make_arg_parser(parser)
args = parser.parse_args()

asyncio.run(run_server(args))
uvloop.run(run_server(args))
@ -7,6 +7,7 @@ purposes.
|
|||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
|
|
@ -16,18 +17,55 @@ from vllm.utils import FlexibleArgumentParser
|
|||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
lora_list = []
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list") # noqa
|
||||
|
||||
lora_list: List[LoRAModulePath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
if item in [None, '']: # Skip if item is None or empty string
|
||||
continue
|
||||
if '=' in item and ',' not in item: # Old format: name=path
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
else: # Assume JSON format
|
||||
try:
|
||||
lora_dict = json.loads(item)
|
||||
lora = LoRAModulePath(**lora_dict)
|
||||
lora_list.append(lora)
|
||||
except json.JSONDecodeError:
|
||||
parser.error(
|
||||
f"Invalid JSON format for --lora-modules: {item}")
|
||||
except TypeError as e:
|
||||
parser.error(
|
||||
f"Invalid fields for --lora-modules: {item} - {str(e)}"
|
||||
)
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
adapter_list = []
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list") # noqa
|
||||
|
||||
adapter_list: List[PromptAdapterPath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
|
|
@ -72,8 +110,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
help="LoRA module configurations in the format name=path. "
|
||||
"Multiple modules can be specified.")
|
||||
help="LoRA module configurations in either 'name=path' format"
|
||||
"or JSON format. "
|
||||
"Example (old format): 'name=path' "
|
||||
"Example (new format): "
|
||||
"'{\"name\": \"name\", \"local_path\": \"path\", "
|
||||
"\"base_model_name\": \"id\"}'")
|
||||
parser.add_argument(
|
||||
"--prompt-adapters",
|
||||
type=nullable_str,
|
||||
|
|
@ -91,8 +133,7 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||
parser.add_argument("--response-role",
|
||||
type=nullable_str,
|
||||
default="assistant",
|
||||
help="The role name to return if "
|
||||
"`request.add_generation_prompt=true`.")
|
||||
help="The role name to return if `request.add_generation_prompt=true`.")
|
||||
parser.add_argument("--ssl-keyfile",
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
|
|
@ -139,6 +180,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        action="store_true",
        help="If specified, will run the OpenAI frontend server in the same "
        "process as the model serving engine.")

    parser.add_argument(
        "--enable-auto-tool-choice",
        action="store_true",
        default=False,
        help="Enable auto tool choice for supported models. Use --tool-call-parser "
        "to specify which parser to use")

    parser.add_argument(
        "--tool-call-parser",
        type=str,
        choices=["mistral", "hermes"],
        default=None,
        help="Select the tool call parser depending on the model that you're using."
        " This is used to parse the model-generated tool call into OpenAI API "
        "format. Required for --enable-auto-tool-choice.")

    parser.add_argument(
        "--load-in-low-bit",
        type=str,
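The two new flags above expose OpenAI-style automatic tool calling. As an illustration only (the server URL, model name, and tool schema are placeholders, and this assumes the server was started with --enable-auto-tool-choice plus a matching --tool-call-parser), a client request could look like the following sketch using the openai Python client:

from openai import OpenAI

# Placeholder endpoint and model; vLLM's OpenAI-compatible server ignores the API key.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="placeholder-model",
    messages=[{"role": "user", "content": "What's the weather in Shanghai?"}],
    tools=tools,
    tool_choice="auto",  # only honored when auto tool choice is enabled server-side
)
print(response.choices[0].message.tool_calls)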
@ -154,6 +212,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        'ID numbers being printed in log.'
        '\n\nDefault: Unlimited')

    parser.add_argument(
        "--disable-fastapi-docs",
        action='store_true',
        default=False,
        help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
    )

    return parser

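The hunk above only adds the --disable-fastapi-docs flag; the code that honors it is not shown here. One common way to implement such a flag (a sketch with standard FastAPI constructor arguments, not necessarily the exact code in this commit) is to build the app without its documentation routes:

from fastapi import FastAPI


def build_app(disable_fastapi_docs: bool) -> FastAPI:
    if disable_fastapi_docs:
        # Setting these to None removes /openapi.json, /docs and /redoc.
        return FastAPI(openapi_url=None, docs_url=None, redoc_url=None)
    return FastAPI()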
@ -1,221 +0,0 @@
import asyncio
import signal
from typing import Any, Coroutine

import cloudpickle
import zmq
import zmq.asyncio
from typing_extensions import Never

from vllm import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine

from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
                                         RPCGenerateRequest, RPCUtilityRequest)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)


class AsyncEngineRPCServer:

    def __init__(self, async_engine_args: AsyncEngineArgs,
                 usage_context: UsageContext, port: int, load_in_low_bit: str):
        # Initialize engine first.
        self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
                                                      usage_context=usage_context,
                                                      load_in_low_bit=load_in_low_bit)

        # Initialize context.
        self.context = zmq.asyncio.Context()

        # Init socket for readiness state.
        self.socket = self.context.socket(zmq.constants.ROUTER)
        # Note numeric form of localhost should be used for zmq bind(),
        # see https://stackoverflow.com/a/8958414
        self.socket.bind(f"tcp://127.0.0.1:{port}")

    def cleanup(self):
        """Cleanup all resources."""
        self.socket.close()
        self.context.destroy()

    async def get_model_config(self, identity):
        """Send the ModelConfig"""
        model_config = await self.engine.get_model_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(model_config)])

    async def get_decoding_config(self, identity):
        """Send the DecodingConfig"""
        decoding_config = await self.engine.get_decoding_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(decoding_config)])

    async def get_lora_config(self, identity):
        lora_config = await self.engine.get_lora_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(lora_config)])

    async def get_scheduler_config(self, identity):
        """Send the SchedulerConfig"""
        parallel_config = await self.engine.get_scheduler_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def get_parallel_config(self, identity):
        """Send the ParallelConfig"""
        parallel_config = await self.engine.get_parallel_config()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(parallel_config)])

    async def is_tracing_enabled(self, identity):
        """Send the is_tracing_enabled flag"""
        tracing_flag = await self.engine.is_tracing_enabled()

        await self.socket.send_multipart(
            [identity, cloudpickle.dumps(tracing_flag)])

    async def do_log_stats(self, identity):
        """Log stats and confirm success."""
        await self.engine.do_log_stats()

        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def is_server_ready(self, identity):
        """Notify the client that we are ready."""
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        # Abort the request in the llm engine.
        await self.engine.abort(request.request_id)

        # Send confirmation to the client.
        await self.socket.send_multipart([
            identity,
            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
        ])

    async def generate(self, identity, generate_request: RPCGenerateRequest):
        try:
            results_generator = self.engine.generate(
                generate_request.inputs,
                sampling_params=generate_request.sampling_params,
                request_id=generate_request.request_id,
                lora_request=generate_request.lora_request,
                trace_headers=generate_request.trace_headers,
                prompt_adapter_request=generate_request.prompt_adapter_request)

            async for request_output in results_generator:
                await self.socket.send_multipart(
                    [identity, cloudpickle.dumps(request_output)])

        except Exception as e:
            # Notify client of all failures
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    async def check_health(self, identity):
        try:
            await self.engine.check_health()
            await self.socket.send_multipart(
                [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
        except Exception as e:
            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

    def _make_handler_coro(self, identity,
                           message) -> Coroutine[Any, Any, Never]:
        """Route the zmq message to the handler coroutine."""

        request = cloudpickle.loads(message)

        if isinstance(request, RPCGenerateRequest):
            return self.generate(identity, request)

        elif isinstance(request, RPCAbortRequest):
            return self.abort(identity, request)

        elif isinstance(request, RPCUtilityRequest):
            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                return self.get_model_config(identity)
            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
                return self.get_parallel_config(identity)
            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
                return self.get_decoding_config(identity)
            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
                return self.get_scheduler_config(identity)
            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
                return self.get_lora_config(identity)
            elif request == RPCUtilityRequest.DO_LOG_STATS:
                return self.do_log_stats(identity)
            elif request == RPCUtilityRequest.IS_SERVER_READY:
                return self.is_server_ready(identity)
            elif request == RPCUtilityRequest.CHECK_HEALTH:
                return self.check_health(identity)
            elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                return self.is_tracing_enabled(identity)
            else:
                raise ValueError(f"Unknown RPCUtilityRequest type: {request}")  # noqa

        else:
            raise ValueError(f"Unknown RPCRequest type: {request}")  # noqa

    async def run_server_loop(self):
        """Inner RPC Server Loop"""

        running_tasks = set()
        while True:
            # Wait for a request.
            identity, message = await self.socket.recv_multipart()

            # Process the request async.
            task = asyncio.create_task(
                self._make_handler_coro(identity, message))

            # We need to keep around a strong reference to the task,
            # to avoid the task disappearing mid-execution as running tasks
            # can be GC'ed. Below is a common "fire-and-forget" tasks
            # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
            running_tasks.add(task)
            task.add_done_callback(running_tasks.discard)


async def run_server(server: AsyncEngineRPCServer):
    # Put the server task into the asyncio loop.
    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.run_server_loop())

    # Interruption handling.
    def signal_handler() -> None:
        # Kill the server on interrupt / terminate
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        logger.info("vLLM ZMQ RPC Server was interrupted.")
    finally:
        # Clean up all resources.
        server.cleanup()


def run_rpc_server(async_engine_args: AsyncEngineArgs,
                   usage_context: UsageContext, port: int, load_in_low_bit: str):
    server = AsyncEngineRPCServer(async_engine_args, usage_context, port, load_in_low_bit)
    asyncio.run(run_server(server))

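The file removed above is the ZMQ-based RPC server that is no longer needed with vLLM 0.6.2. One detail from it worth keeping in mind is the pattern of anchoring fire-and-forget asyncio tasks with a strong reference; a minimal self-contained illustration of that pattern (unrelated to vLLM's own classes) is:

import asyncio


async def handle(i: int) -> None:
    # Stand-in for handling one RPC message.
    await asyncio.sleep(0)
    print(f"handled request {i}")


async def main() -> None:
    # asyncio only keeps weak references to tasks, so fire-and-forget tasks
    # must be stored somewhere or they may be garbage collected mid-run.
    running_tasks: set[asyncio.Task] = set()
    for i in range(3):
        task = asyncio.create_task(handle(i))
        running_tasks.add(task)
        task.add_done_callback(running_tasks.discard)
    # Give the handlers a chance to finish before exiting.
    await asyncio.gather(*running_tasks)


asyncio.run(main())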
@ -75,14 +75,13 @@ def get_load_function(low_bit):
        _model_sample_convert()

        # from vllm.utils import measure_device_memory
        from vllm.utils import CudaMemoryProfiler
        with CudaMemoryProfiler() as m:
        from vllm.utils import DeviceMemoryProfiler
        with DeviceMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=DeviceConfig("cpu"),
                load_config=self.load_config,
                lora_config=self.lora_config,
                multimodal_config=self.multimodal_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
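The only change in this last hunk is the rename of vLLM's memory profiler (CudaMemoryProfiler became DeviceMemoryProfiler in the 0.6.x line); it is still used as a context manager wrapped around model loading. A rough sketch of the usage pattern follows; it assumes a working vLLM 0.6.2 install with a supported device, the tensor stands in for the real get_model(...) call, and the consumed_memory attribute and GiB conversion are assumptions based on how vLLM workers typically report this figure.

# Sketch only, not part of this commit.
import torch
from vllm.utils import DeviceMemoryProfiler

with DeviceMemoryProfiler() as m:
    weights = torch.empty(1024, 1024, dtype=torch.float16)  # placeholder workload

print(f"Loading consumed {m.consumed_memory / float(2**30):.4f} GiB")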