vLLM: Update vLLM-cpu to v0.6.6-post1 (#12728)

Xiangyu Tian 2025-01-22 15:03:01 +08:00 committed by GitHub
parent 78cca0a68c
commit c9b6c94a59
11 changed files with 2086 additions and 410 deletions

View file

@@ -16,6 +16,8 @@ RUN wget -qO /sbin/tini https://github.com/krallin/tini/releases/download/${TINI
apt-get update && \
apt-get install -y --no-install-recommends wrk patch g++ && \
pip install --pre --upgrade ipex-llm[serving] && \
+ apt-get install -y gcc-12 g++-12 libnuma-dev && \
+ update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 && \
# Fix Trivy CVE Issues
pip install Jinja2==3.1.3 transformers==4.36.2 gradio==4.19.2 cryptography==42.0.4 && \
# Fix Qwen model adapter in fastchat
@@ -24,10 +26,11 @@ RUN wget -qO /sbin/tini https://github.com/krallin/tini/releases/download/${TINI
# Install vllm
git clone https://github.com/vllm-project/vllm.git && \
cd ./vllm && \
- git checkout v0.4.2 && \
- pip install wheel packaging ninja setuptools>=49.4.0 numpy && \
+ git checkout v0.6.6.post1 && \
+ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy && \
pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu && \
- VLLM_TARGET_DEVICE=cpu python3 setup.py install
+ VLLM_TARGET_DEVICE=cpu python3 setup.py install && \
+ pip install ray
COPY ./vllm_offline_inference.py /llm/
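For reference, a quick sanity check of the rebuilt image as a minimal Python sketch; it only assumes that python3 inside the container sees the freshly built vLLM wheel and the Ray package installed above.

# Sanity-check sketch: run inside the serving container after the build above.
import vllm
import ray

assert vllm.__version__.startswith("0.6.6"), vllm.__version__
print("vLLM", vllm.__version__, "| Ray", ray.__version__)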

View file

@@ -693,7 +693,6 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
out_features,
mp_group,
None,
- None,
optimize_lm_head,
None
)

View file

@@ -749,7 +749,7 @@ class LowBitLinear(nn.Linear):
dist.inference_all_reduce(result, group=self.mp_group)
if self.bias is not None:
result += self.bias
- return result
+ return result.to(x.dtype)
class FP16Linear(nn.Linear):
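The extra .to(x.dtype) keeps the layer's output dtype aligned with its input. A minimal sketch of the mismatch this guards against, assuming the low-bit kernel returns float32 while the surrounding model runs in bfloat16 (the tensors below are stand-ins, not the real kernel):

import torch

x = torch.randn(2, 8, dtype=torch.bfloat16)        # activations fed to LowBitLinear
result = torch.randn(2, 8, dtype=torch.float32)    # stand-in for the kernel output
bias = torch.zeros(8, dtype=torch.float32)

result += bias                                      # bias add, as in forward()
result = result.to(x.dtype)                         # mirrors `return result.to(x.dtype)`
assert result.dtype == x.dtype                      # downstream layers see a consistent dtype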

View file

@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
- from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass
+ from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
"run_mp_engine",
]
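run_mp_engine is re-exported here so the OpenAI API server later in this diff can start the IPEX-LLM engine in a spawned process. A hedged sketch of that wiring, mirroring the server code below; the model name is a placeholder:

import multiprocessing

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import get_open_zmq_ipc_path
from ipex_llm.vllm.cpu.engine import run_mp_engine

engine_args = AsyncEngineArgs(model="facebook/opt-125m")   # placeholder model
ipc_path = get_open_zmq_ipc_path()
ctx = multiprocessing.get_context("spawn")                 # same spawn context as the API server
engine_alive = ctx.Value('b', True, lock=False)            # shared liveness flag

proc = ctx.Process(target=run_mp_engine,
                   args=(engine_args, UsageContext.OPENAI_API_SERVER,
                         ipc_path, "sym_int4", engine_alive))
proc.start()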

View file

@@ -13,18 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Union
from vllm.logger import init_logger
from typing import Dict, Optional, Any, Union, Type
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
from vllm.config import VllmConfig
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
from vllm.usage.usage_lib import UsageContext
from vllm.engine.metrics import StatLoggerBase
from vllm.engine.multiprocessing.engine import MQLLMEngine
import signal
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.config import CompilationConfig
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
from vllm import envs
from vllm.v1.engine.async_llm import AsyncLLM
import os
from ipex_llm.utils.common import invalidInputError
logger = init_logger(__name__)
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
@@ -35,49 +45,43 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: Optional[str] = None,
load_in_low_bit: str = "sym_int4",
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Enable ipex-llm optimizations
engine_config = engine_args.create_engine_config()
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
invalidInputError(not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend."))
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
else:
invalidInputError(engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1."))
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
)
return engine
usage_context=usage_context, stat_loggers=stat_loggers)
class IPEXLLMAsyncV1Engine(AsyncLLM):
def __init__(self, *args, **kwargs):
print("IPEX-LLM V1 engine get started...")
super().__init__(*args, **kwargs)
@classmethod
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[VllmConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: str = "sym_int4",
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
) -> "AsyncLLM":
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)
class IPEXLLMClass(LLM):
def __init__(
self,
model: str,
@@ -85,6 +89,7 @@ class IPEXLLMClass(LLM):
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
@@ -92,22 +97,48 @@ class IPEXLLMClass(LLM):
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
load_in_low_bit: Optional[str] = None,
disable_async_output_proc: bool = True,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]]=None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
load_in_low_bit: str = "sym_int4",
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
str(compilation_config))
else:
compilation_config_instance = compilation_config
else:
compilation_config_instance = None
engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
@@ -116,16 +147,60 @@ class IPEXLLMClass(LLM):
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
override_pooler_config=override_pooler_config,
compilation_config=compilation_config_instance,
**kwargs,
)
self.llm_engine = IPEXLLMLLMEngine.from_engine_args(engine_args,
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
# TODO(gc): we will need to override this function
self.engine_class = self.get_engine_class()
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS,
load_in_low_bit=load_in_low_bit)
self.request_counter = Counter()
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
if envs.VLLM_USE_V1:
# Lazy import: the v1 package isn't distributed
# from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
return IPEXLLMLLMV1Engine # type: ignore
return IPEXLLMLLMEngine
# TODO(gc): implement this later...
class IPEXLLMLLMV1Engine(V1LLMEngine):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
enable_multiprocessing: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
# TODO(gc): delete this later
print("IPEXLLM V1 Engine")
# This does not work as it is in the separate process...
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context,
stat_loggers, enable_multiprocessing)
class IPEXLLMLLMEngine(LLMEngine):
def __init__(self, *args, **kwargs):
@@ -136,35 +211,44 @@ class IPEXLLMLLMEngine(LLMEngine):
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
load_in_low_bit: Optional[str] = None,
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
# TODO(gc): Delete
print("Use vLLM v0 engine")
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, stat_loggers)
# Initialize the cluster and specify the executor class.
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor
elif engine_config.parallel_config.worker_use_ray:
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutor
executor_class = RayGPUExecutor
else:
invalidInputError(engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1."))
from vllm.executor.gpu_executor import GPUExecutor
executor_class = GPUExecutor
# Create the LLM engine.
engine = cls(**engine_config.to_dict(),
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
class IPEXLLMMQLLMEngine(MQLLMEngine):
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
_ipex_llm_convert(load_in_low_bit)
return super().from_engine_args(engine_args, usage_context, ipc_path)
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, load_in_low_bit: str, engine_alive):
def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
try:
signal.signal(signal.SIGTERM, signal_handler)
engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
)
return engine
ipc_path=ipc_path,
load_in_low_bit=load_in_low_bit)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
raise e # noqa
if os.getenv("VLLM_USE_V1"):
IPEXLLMAsyncLLMEngine = IPEXLLMAsyncV1Engine
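Taken together, the new engine wrappers can be driven offline much like upstream vLLM's LLM class. A minimal usage sketch, assuming a CPU-supported model path is substituted for the placeholder; load_in_low_bit defaults to "sym_int4" as in the constructor above:

from vllm import SamplingParams
from ipex_llm.vllm.cpu.engine import IPEXLLMClass

llm = IPEXLLMClass(model="facebook/opt-125m",          # placeholder model path
                   load_in_low_bit="sym_int4")
outputs = llm.generate(["What is AI?"],
                       SamplingParams(temperature=0.8, max_tokens=32))
for out in outputs:
    print(out.outputs[0].text)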

View file

@@ -0,0 +1,787 @@
import asyncio
import atexit
import importlib
import inspect
import multiprocessing
import os
import re
import signal
import socket
import tempfile
import uuid
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
# from vllm.engine.async_llm_engine import AsyncLLMEngine # type: ignore
from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
# from vllm.engine.multiprocessing.engine import run_mp_engine
from ipex_llm.vllm.cpu.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
# from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10.)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
yield engine
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
load_in_low_bit: str = "sym_int4",
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
load_in_low_bit=load_in_low_bit,
engine_args=engine_args,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
if hasattr(engine_client, "shutdown"):
engine_client.shutdown()
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
ipc_path = get_open_zmq_ipc_path()
logger.debug("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
# The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path, load_in_low_bit, engine_alive))
engine_process.start()
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start."
logger.info("Started engine process with PID %d", engine_pid)
def _cleanup_ipc_path():
socket_path = ipc_path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
# Ensure we clean up the local IPC socket file on exit.
atexit.register(_cleanup_ipc_path)
# Build RPCClient, which conforms to EngineClient Protocol.
engine_config = engine_args.create_engine_config()
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client)
try:
while True:
try:
await mq_engine_client.setup()
break
except TimeoutError:
if (not engine_process.is_alive()
or not engine_alive.value):
raise RuntimeError(
"Engine process failed to start. See stack "
"trace for the root cause.") from None
yield mq_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
engine_process.terminate()
# Close all open connections to the backend
mq_engine_client.close()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
def base(request: Request) -> OpenAIServing:
# Reuse the existing instance
return tokenization(request)
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion
def pooling(request: Request) -> Optional[OpenAIServingPooling]:
return request.app.state.openai_serving_pooling
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
def score(request: Request) -> Optional[OpenAIServingScores]:
return request.app.state.openai_serving_scores
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health")
async def health(raw_request: Request) -> Response:
"""Health check."""
await engine_client(raw_request).check_health()
return Response(status_code=200)
@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
generator = await handler.create_tokenize(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
generator = await handler.create_detokenize(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = base(raw_request)
models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())
@router.get("/version")
async def show_version():
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)
@router.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Chat Completions API")
generator = await handler.create_chat_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API")
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
fallback_handler = pooling(raw_request)
if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/pooling")
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Pooling API")
generator = await handler.create_pooling(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, PoolingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Score API")
generator = await handler.create_score(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/v1/score")
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
"have moved it to `/score`. Please update your client accordingly.")
return await create_score(request, raw_request)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
def build_app(args: Namespace) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
docs_url=None,
redoc_url=None,
lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
mount_metrics(app)
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = ErrorResponse(message=str(exc),
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST)
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
if request.method == "OPTIONS":
return await call_next(request)
url_path = request.url.path
if app.root_path and url_path.startswith(app.root_path):
url_path = url_path[len(app.root_path):]
if not url_path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
if args.enable_request_id_headers:
logger.warning(
"CAUTION: Enabling X-Request-Id headers in the API Server. "
"This can harm performance at high QPS.")
@app.middleware("http")
async def add_request_id(request: Request, call_next):
request_id = request.headers.get(
"X-Request-Id") or uuid.uuid4().hex
response = await call_next(request)
response.headers["X-Request-Id"] = request_id
return response
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
return app
def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
family = socket.AF_INET
if is_valid_ipv6_address(addr[0]):
family = socket.AF_INET6
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(addr)
return sock
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valide_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit()
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
sock.close()
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser.add_argument(
"--load-in-low-bit",
type=str,
default="sym_int4",
help="Low-bit quantization for IPEX-LLM models")
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))
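Once this server is running (started with --load-in-low-bit and listening on the default host/port), requests go through the standard OpenAI-compatible routes defined above. A client-side sketch; the base URL and model name are assumptions:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",       # assumed host/port
    json={
        "model": "placeholder-model",                   # must match the served model name
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 32,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])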

View file

@@ -1,138 +1,559 @@
import asyncio
import atexit
import importlib
import inspect
import multiprocessing
import os
import re
import signal
import socket
import tempfile
import uuid
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import Any, Set
from typing import AsyncIterator, Optional, Set, Tuple
import fastapi
import uvicorn
from fastapi import Request
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
import vllm
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from ipex_llm.vllm.cpu.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest, ErrorResponse)
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from ipex_llm.vllm.cpu.engine import IPEXLLMAsyncLLMEngine
from ipex_llm.utils.common import invalidInputError
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
logger = init_logger(__name__)
prometheus_multiproc_dir: tempfile.TemporaryDirectory
_running_tasks: Set[asyncio.Task[Any]] = set()
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
await asyncio.sleep(10.)
await engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state
app = fastapi.FastAPI(lifespan=lifespan)
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[EngineClient]:
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing, args.load_in_low_bit) as engine:
yield engine
def parse_args():
parser = make_arg_parser()
parser.add_argument(
"--load-in-low-bit",
type=str,
default=None,
help="Low-bit quantization for IPEX-LLM models")
return parser.parse_args()
@asynccontextmanager
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
load_in_low_bit: str = "sym_int4",
) -> AsyncIterator[EngineClient]:
"""
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config(
UsageContext.OPENAI_API_SERVER)
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_args=engine_args,
engine_config=engine_config,
load_in_low_bit=load_in_low_bit,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
if hasattr(engine_client, "shutdown"):
engine_client.shutdown()
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global prometheus_multiproc_dir
prometheus_multiproc_dir = tempfile.TemporaryDirectory()
os.environ[
"PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
else:
logger.warning(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
ipc_path = get_open_zmq_ipc_path()
logger.debug("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
# The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path, load_in_low_bit, engine_alive))
engine_process.start()
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start."
logger.info("Started engine process with PID %d", engine_pid)
def _cleanup_ipc_path():
socket_path = ipc_path.replace("ipc://", "")
if os.path.exists(socket_path):
os.remove(socket_path)
# Ensure we clean up the local IPC socket file on exit.
atexit.register(_cleanup_ipc_path)
# Build RPCClient, which conforms to EngineClient Protocol.
engine_config = engine_args.create_engine_config()
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client)
try:
while True:
try:
await mq_engine_client.setup()
break
except TimeoutError:
if (not engine_process.is_alive()
or not engine_alive.value):
raise RuntimeError(
"Engine process failed to start. See stack "
"trace for the root cause.") from None
yield mq_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
engine_process.terminate()
# Close all open connections to the backend
mq_engine_client.close()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
def mount_metrics(app: FastAPI):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
# Add prometheus asgi middleware to route /metrics requests
route = Mount("/metrics", make_asgi_app())
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
else:
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
def base(request: Request) -> OpenAIServing:
# Reuse the existing instance
return tokenization(request)
@app.get("/health")
async def health() -> Response:
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion
def pooling(request: Request) -> Optional[OpenAIServingPooling]:
return request.app.state.openai_serving_pooling
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
def score(request: Request) -> Optional[OpenAIServingScores]:
return request.app.state.openai_serving_scores
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health")
async def health(raw_request: Request) -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
await engine_client(raw_request).check_health()
return Response(status_code=200)
@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
@router.post("/tokenize")
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
generator = await handler.create_tokenize(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/detokenize")
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
generator = await handler.create_detokenize(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = base(raw_request)
models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())
@app.get("/version")
@router.get("/version")
async def show_version():
ver = {"version": vllm.__version__}
ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver)
@app.post("/v1/chat/completions")
@router.post("/v1/chat/completions")
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
handler = chat(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Chat Completions API")
generator = await handler.create_chat_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@app.post("/v1/completions")
@router.post("/v1/completions")
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API")
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
if __name__ == "__main__":
args = parse_args()
@router.post("/v1/embeddings")
@with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
fallback_handler = pooling(raw_request)
if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/pooling")
@with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Pooling API")
generator = await handler.create_pooling(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, PoolingResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/score")
@with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Score API")
generator = await handler.create_score(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/v1/score")
@with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
"have moved it to `/score`. Please update your client accordingly.")
return await create_score(request, raw_request)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!")
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
def build_app(args: Namespace) -> FastAPI:
if args.disable_fastapi_docs:
app = FastAPI(openapi_url=None,
docs_url=None,
redoc_url=None,
lifespan=lifespan)
else:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
mount_metrics(app)
app.add_middleware(
CORSMiddleware,
@@ -142,18 +563,43 @@ if __name__ == "__main__":
allow_headers=args.allowed_headers,
)
token = os.environ.get("VLLM_API_KEY") or args.api_key
if token:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = ErrorResponse(message=str(exc),
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST)
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
root_path = "" if args.root_path is None else args.root_path
if not request.url.path.startswith(f"{root_path}/v1"):
if request.method == "OPTIONS":
return await call_next(request)
url_path = request.url.path
if app.root_path and url_path.startswith(app.root_path):
url_path = url_path[len(app.root_path):]
if not url_path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
if args.enable_request_id_headers:
logger.warning(
"CAUTION: Enabling X-Request-Id headers in the API Server. "
"This can harm performance at high QPS.")
@app.middleware("http")
async def add_request_id(request: Request, call_next):
request_id = request.headers.get(
"X-Request-Id") or uuid.uuid4().hex
response = await call_next(request)
response.headers["X-Request-Id"] = request_id
return response
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
@ -162,30 +608,145 @@ if __name__ == "__main__":
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
invalidInputError(False, (f"Invalid middleware {middleware}. "
f"Must be a function or a class."))
raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.")
logger.info("vLLM API server version %s", vllm.__version__)
logger.info("args: %s", args)
return app
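# Illustrative sketch (not part of this diff): exercising the middleware wired
# up in build_app above. With --api-key (or VLLM_API_KEY) configured, routes
# under /v1 require a matching Bearer token, and with
# --enable-request-id-headers the response echoes X-Request-Id back. The URL,
# key and request id below are placeholders.
def _example_authenticated_request(base_url: str = "http://localhost:8000") -> None:
    import requests  # hypothetical client-side dependency, not used by the server

    resp = requests.get(f"{base_url}/v1/models",
                        headers={"Authorization": "Bearer my-secret-key",
                                 "X-Request-Id": "debug-1234"})
    # A 401 with {"error": "Unauthorized"} is returned when the token does not
    # match; otherwise the request id set above comes back in the headers.
    print(resp.status_code, resp.headers.get("X-Request-Id"))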
def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = IPEXLLMAsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER,
load_in_low_bit=args.load_in_low_bit,
)
openai_serving_chat = OpenAIServingChat(engine, served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model_names, args.lora_modules)
app.root_path = args.root_path
uvicorn.run(app,
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
base_model_paths,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.runner_type == "pooling" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
family = socket.AF_INET
if is_valid_ipv6_address(addr[0]):
family = socket.AF_INET6
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(addr)
return sock
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(choose from {{ {','.join(valid_tool_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit()
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
@ -193,4 +754,28 @@ if __name__ == "__main__":
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs)
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
sock.close()
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
parser.add_argument(
"--load-in-low-bit",
type=str,
default="sym_int4",
help="Low-bit quantization for IPEX-LLM models")
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))

View file

@ -0,0 +1,277 @@
"""
This file contains the command line arguments for vLLM's
OpenAI-compatible server. It is kept in a separate file for documentation
purposes.
"""
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser
class LoRAParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list") # noqa
lora_list: List[LoRAModulePath] = []
for item in values:
if item in [None, '']: # Skip if item is None or empty string
continue
if '=' in item and ',' not in item: # Old format: name=path
name, path = item.split('=')
lora_list.append(LoRAModulePath(name, path))
else: # Assume JSON format
try:
lora_dict = json.loads(item)
lora = LoRAModulePath(**lora_dict)
lora_list.append(lora)
except json.JSONDecodeError:
parser.error(
f"Invalid JSON format for --lora-modules: {item}")
except TypeError as e:
parser.error(
f"Invalid fields for --lora-modules: {item} - {str(e)}"
)
setattr(namespace, self.dest, lora_list)
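# Illustrative sketch (not part of this diff): how LoRAParserAction turns the
# legacy 'name=path' form into LoRAModulePath entries. A throwaway
# argparse.ArgumentParser is used purely for demonstration; the adapter name
# and path are placeholders.
def _example_parse_lora_modules() -> List[LoRAModulePath]:
    demo = argparse.ArgumentParser()
    demo.add_argument("--lora-modules", nargs="+", action=LoRAParserAction)
    ns = demo.parse_args(["--lora-modules", "sql-lora=/path/to/sql-lora"])
    # ns.lora_modules holds a single LoRAModulePath("sql-lora", "/path/to/sql-lora")
    return ns.lora_modules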
class PromptAdapterParserAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Optional[Union[str, Sequence[str]]],
option_string: Optional[str] = None,
):
if values is None:
values = []
if isinstance(values, str):
raise TypeError("Expected values to be a list") # noqa
adapter_list: List[PromptAdapterPath] = []
for item in values:
name, path = item.split('=')
adapter_list.append(PromptAdapterPath(name, path))
setattr(namespace, self.dest, adapter_list)
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--uvicorn-log-level",
type=str,
default="info",
choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'],
help="log level for uvicorn")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument("--api-key",
type=nullable_str,
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=nullable_str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): 'name=path' "
"Example (new format): "
"'{\"name\": \"name\", \"local_path\": \"path\", "
"\"base_model_name\": \"id\"}'")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=nullable_str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument(
'--chat-template-content-format',
type=str,
default="auto",
choices=get_args(ChatTemplateContentFormatOption),
help='The format to render message content within a chat template.'
'\n\n'
'* "string" will render the content as a string. '
'Example: "Hello World"\n'
'* "openai" will render the content as a list of dictionaries, '
'similar to OpenAI schema. '
'Example: [{"type": "text", "text": "Hello world!"}]')
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=nullable_str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=nullable_str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs",
type=nullable_str,
default=None,
help="The CA certificates file")
parser.add_argument(
"--ssl-cert-reqs",
type=int,
default=int(ssl.CERT_NONE),
help="Whether client certificate is required (see stdlib ssl module's)"
)
parser.add_argument(
"--root-path",
type=nullable_str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=nullable_str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server "
"using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server "
"using app.add_middleware(). ")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When --max-logprobs is specified, represents single tokens as "
"strings of the form 'token_id:{token_id}' so that tokens that "
"are not JSON-encodable can be identified.")
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser.add_argument(
"--enable-request-id-headers",
action="store_true",
help="If specified, API server will add X-Request-Id header to "
"responses. Caution: this hurts performance at high QPS.")
parser.add_argument(
"--enable-auto-tool-choice",
action="store_true",
default=False,
help="Enable auto tool choice for supported models. Use --tool-call-parser"
" to specify which parser to use")
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
parser.add_argument(
"--tool-call-parser",
type=str,
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
"--tool-parser-plugin",
default=None,
help="Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice.")
parser.add_argument(
"--tool-parser-plugin",
type=str,
default="",
help="Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument(
"--disable-fastapi-docs",
action='store_true',
default=False,
help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint"
)
parser.add_argument(
"--enable-prompt-tokens-details",
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
parser.add_argument(
"--load-in-low-bit",
type=str,
default="sym_int4",
help="Low-bit quantization for IPEX-LLM models")
return parser
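# Illustrative sketch (not part of this diff): building the full CLI the same
# way the api_server __main__ block does, then reading the IPEX-LLM specific
# --load-in-low-bit flag added above. The model name is a placeholder.
def _example_build_and_parse() -> argparse.Namespace:
    parser = make_arg_parser(FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args(["--model", "facebook/opt-125m",
                              "--load-in-low-bit", "sym_int4"])
    # args.load_in_low_bit is later passed through to the IPEX-LLM engine
    # (e.g. IPEXLLMAsyncLLMEngine.from_engine_args(..., load_in_low_bit=...)).
    return args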
def validate_parsed_serve_args(args: argparse.Namespace):
"""Quick checks for model serve args that raise prior to loading.""" # noqa
if hasattr(args, "subparser") and args.subparser != "serve":
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template(args.chat_template)
# Enable auto tool needs a tool call parser to be valid
if args.enable_auto_tool_choice and not args.tool_call_parser:
raise TypeError("Error: --enable-auto-tool-choice requires " # noqa
"--tool-call-parser")
def create_parser_for_docs() -> FlexibleArgumentParser:
parser_for_docs = FlexibleArgumentParser(
prog="-m vllm.entrypoints.openai.api_server")
return make_arg_parser(parser_for_docs)

View file

@ -0,0 +1,23 @@
from vllm.logger import init_logger
from vllm.v1.executor.ray_utils import RayWorkerWrapper
logger = init_logger(__name__)
class IPEXLLMV1Wrapper(RayWorkerWrapper):
def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
_ipex_llm_convert(load_in_low_bit=load_in_low_bit)
self.compiled_dag_cuda_device_set = False
def get_ipex_llm_v1_wrapper(load_in_low_bit):
# We are not using functools.partial here because Ray does not seem to
# work well with it.
class WrapperWithLoadBit(IPEXLLMV1Wrapper):
def __init__(self, *args, **kwargs) -> None:
super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)
return WrapperWithLoadBit

View file

@ -0,0 +1,24 @@
from vllm.logger import init_logger
from vllm.executor.ray_utils import RayWorkerWrapper
logger = init_logger(__name__)
class IPEXLLMWrapper(RayWorkerWrapper):
def __init__(self, load_in_low_bit="sym_int4", *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
from ipex_llm.vllm.cpu.model_convert import _ipex_llm_convert
_ipex_llm_convert(load_in_low_bit=load_in_low_bit)
self.compiled_dag_cuda_device_set = False
def get_ipex_llm_wrapper(load_in_low_bit):
# We are not using functools.partial here because Ray does not seem to
# work well with it.
class WrapperWithLoadBit(IPEXLLMWrapper):
def __init__(self, *args, **kwargs) -> None:
super().__init__(load_in_low_bit=load_in_low_bit, *args, **kwargs)
# a = functools.partial(IPEXLLMWrapper, load_in_low_bit=load_in_low_bit)
return WrapperWithLoadBit
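# Illustrative sketch (not part of this diff): the closure-class pattern used
# above, shown on a toy base class. Ray serializes the wrapper class itself,
# so a dynamically created subclass that bakes in load_in_low_bit travels to
# the worker cleanly, whereas functools.partial reportedly does not.
class _ToyWorker:
    def __init__(self, low_bit: str = "sym_int4") -> None:
        self.low_bit = low_bit


def _get_toy_wrapper(load_in_low_bit: str):
    class ToyWrapperWithLoadBit(_ToyWorker):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(low_bit=load_in_low_bit, *args, **kwargs)
    return ToyWrapperWithLoadBit

# _get_toy_wrapper("bf16")().low_bit == "bf16"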

View file

@ -14,259 +14,152 @@
# limitations under the License.
#
import torch
from typing import Optional, Union
from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.utils import get_model_architecture
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention
from vllm.model_executor.models.llama import LlamaMLP, LlamaAttention, LlamaForCausalLM
from vllm.model_executor.models.qwen2 import Qwen2MLP, Qwen2Attention, Qwen2ForCausalLM
from vllm.model_executor.models.qwen import QWenMLP, QWenAttention, QWenLMHeadModel
from vllm.model_executor.models.baichuan import BaiChuanMLP, BaiChuanAttention
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention
from vllm.attention import Attention, AttentionMetadata
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
from vllm.model_executor.models.chatglm import GLMMLP, GLMAttention, ChatGLMForCausalLM
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.attention import AttentionMetadata
from vllm.config import DeviceConfig
from vllm.logger import init_logger
from vllm._C import ops
from ipex_llm.utils.common import invalidInputError
from typing import List, Optional, Tuple, Union
logger = init_logger(__name__)
from typing import Tuple
from ipex_llm.transformers.low_bit_linear import LowBitLinear
def _MLP_forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.down_proj(x)
return x
def _Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
output = self.o_proj(attn_output)
return output
def _QWen_Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.c_attn(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.c_proj(attn_output)
return output
def _QWen_MLP_forward(self, x):
gate_up = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x = self.c_proj(x)
return x
def _Qwen2_Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.o_proj(attn_output)
return output
def _ChatGLM_MLP_forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output = self.dense_4h_to_h(intermediate_parallel)
return output
def _Baichuan_Attention_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv = self.W_pack(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output = self.o_proj(attn_output)
return output
def _ChatGLM_Attention_forward(
def _sample_get_logits(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: AttentionMetadata,
lm_head: Union[VocabParallelEmbedding, LowBitLinear],
embedding_bias: Optional[torch.Tensor],
) -> torch.Tensor:
qkv = self.query_key_value(hidden_states).to(dtype=kv_cache.dtype)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
attn_output = self.dense(context_layer)
return attn_output
_REPLACED_MLP_LAYERS = {
LlamaMLP: _MLP_forward,
Qwen2MLP: _MLP_forward,
BaiChuanMLP: _MLP_forward,
# QWenMLP: _QWen_MLP_forward,
GLMMLP: _ChatGLM_MLP_forward
}
_REPLACED_ATTENTION_LAYERS = {
LlamaAttention: _Attention_forward,
Qwen2Attention: _Qwen2_Attention_forward,
# QWenAttention: _QWen_Attention_forward,
BaiChuanAttention: _Baichuan_Attention_forward,
GLMAttention: _ChatGLM_Attention_forward
}
_IPEX_LLM_SUPPORTED_MODELS = [
"LlamaForCausalLM",
"BaichuanForCausalLM",
"ChatGLMForCausalLM",
"Qwen2ForCausalLM",
]
# HINT: we do not support other types of quantization for now
# TODO: we may encounter tie-word-embedding problems
if isinstance(lm_head, VocabParallelEmbedding):
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
else:
logits = lm_head(hidden_states)
if embedding_bias is not None:
logits += embedding_bias
if self.use_gather:
logits = tensor_model_parallel_gather(logits)
else:
logits = tensor_model_parallel_all_gather(logits)
if logits is not None:
logits = logits[:, : self.org_vocab_size]
return logits
def _model_mlp_convert():
for module, replaced_func in _REPLACED_MLP_LAYERS.items():
setattr(module, "forward", replaced_func)
def _model_attention_convert():
for module, replaced_func in _REPLACED_ATTENTION_LAYERS.items():
setattr(module, "forward", replaced_func)
def _model_sample_convert():
from vllm.model_executor.layers.logits_processor import LogitsProcessor
setattr(LogitsProcessor, "_get_logits", _sample_get_logits)
def _ipex_llm_convert(load_in_low_bit):
if load_in_low_bit is None:
return
from vllm.worker.cpu_model_runner import CPUModelRunner
import vllm.model_executor.model_loader as model_loader
from ipex_llm.vllm.cpu.ipex_llm_wrapper import get_ipex_llm_wrapper
from ipex_llm.vllm.cpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper
import vllm.executor.ray_utils as ray_utils_v0
import vllm.v1.executor.ray_utils as ray_utils_v1
setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
setattr(RotaryEmbedding, "forward", _ipex_llm_rotary_embedding_forward)
from vllm.model_executor.layers.layernorm import RMSNorm
setattr(RMSNorm, "forward", _ipex_llm_rmsnorm_forward)
def _ipex_llm_rotary_embedding_forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def _ipex_llm_rmsnorm_forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
x = x.to(dtype=self.weight.data.dtype)
if residual is not None:
residual = residual.to(dtype=self.weight.data.dtype)
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
setattr(ray_utils_v0, "RayWorkerWrapper", get_ipex_llm_wrapper(load_in_low_bit))
setattr(ray_utils_v1, "RayWorkerWrapper", get_ipex_llm_v1_wrapper(load_in_low_bit))
def get_load_function(low_bit):
def _ipex_llm_load_model(self) -> None:
model_class = get_model_architecture(self.model_config)[1]
cur_model_list = ", ".join(_IPEX_LLM_SUPPORTED_MODELS)
if low_bit != "bf16":
invalidInputError(model_class in _IPEX_LLM_SUPPORTED_MODELS,
f"Currently IPEX-LLM vLLM convert only support {cur_model_list}.")
else:
if model_class not in _IPEX_LLM_SUPPORTED_MODELS:
logger.warning(
f"Currently IPEX-LLM vLLM convert only support {cur_model_list}."
_model_sample_convert()
# from vllm.utils import measure_device_memory
# from vllm.utils import DeviceMemoryProfiler
# with DeviceMemoryProfiler() as m:
from dataclasses import replace
new_device_config = DeviceConfig("cpu")
new_vllm_config = replace(self.vllm_config, device_config=new_device_config)
self.model = get_model(
vllm_config=new_vllm_config
)
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
return
# _model_mlp_convert()
# _model_attention_convert()
self.model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
vision_language_config=self.vision_language_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
if "qwen" in self.vllm_config.model_config.model.lower() or \
"baichuan" in self.vllm_config.model_config.model.lower() or \
"codegeex4-all" in self.vllm_config.model_config.model.lower() or \
"chatglm" in self.vllm_config.model_config.model.lower():
self.model.apply(padding_mlp)
from ipex_llm import optimize_model
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
import os
not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)
if not_convert_last_mlp is not None:
# only used to avoid nan values in the last mlp forward when running glm4-9b-chat
modules = ["35.mlp", "36.mlp", "37.mlp", "38.mlp", "39.mlp"]
else:
modules = None
if "minicpm" in self.vllm_config.model_config.model.lower():
modules = ["vpm", "resampler"]
# only for minicpm_2_6
if "minicpm-v" in self.vllm_config.model_config.model.lower():
from ipex_llm.transformers.models.minicpmv import merge_qkv
self.model.vpm.apply(merge_qkv)
if "internvl2" in self.vllm_config.model_config.model.lower():
modules = ["vision_model", "mlp1"]
# print(self.vllm_config.model_config.dtype)
# print("---------------------------------------")
optimize_model(self.model, low_bit=low_bit, torch_dtype=self.vllm_config.model_config.dtype,
modules_to_not_convert=modules)
self.model = self.model.to(device=self.vllm_config.device_config.device,
dtype=self.vllm_config.model_config.dtype)
# print(self.model)
# self.model_memory_usage = m.consumed_memory
# logger = init_logger(__name__)
# logger.info("Loading model weights took %.4f GB",
# self.model_memory_usage / float(2**30))
return _ipex_llm_load_model
def padding_mlp(module: torch.nn.Module):
mlp_gate_up_name = None
mlp_down_name = None
if isinstance(module, Qwen2MLP):
mlp_gate_up_name = "gate_up_proj"
mlp_down_name = "down_proj"
elif isinstance(module, GLMMLP):
mlp_gate_up_name = "dense_h_to_4h"
mlp_down_name = "dense_4h_to_h"
elif isinstance(module, BaiChuanMLP):
mlp_gate_up_name = "gate_up_proj"
mlp_down_name = "down_proj"
else:
return
hidden_size = getattr(module, mlp_down_name).output_size
# divided across tensor parallel ranks
intermediate_size = getattr(module, mlp_down_name).input_size_per_partition
padding_size = 256
padding_intermediate_size = \
(intermediate_size + padding_size - 1) // padding_size * padding_size
if intermediate_size % padding_size == 0:
return
gate_up_weight = getattr(module, mlp_gate_up_name).weight.data
new_gate_up_weight = torch.zeros([padding_intermediate_size * 2, hidden_size],
dtype=gate_up_weight.dtype, device=gate_up_weight.device)
# merge_gate_up_weight
new_gate_up_weight[:intermediate_size, :] = gate_up_weight[:intermediate_size, :]
new_gate_up_weight[padding_intermediate_size:padding_intermediate_size+intermediate_size, :] = gate_up_weight[intermediate_size:, :] # noqa
getattr(module, mlp_gate_up_name).output_size_per_partition = padding_intermediate_size * 2
getattr(module, mlp_gate_up_name).output_size = padding_intermediate_size * 2
getattr(module, mlp_gate_up_name).weight = \
torch.nn.Parameter(new_gate_up_weight, requires_grad=False)
down_weight = getattr(module, mlp_down_name).weight.data
new_down_weight = torch.zeros([hidden_size, padding_intermediate_size],
dtype=down_weight.dtype, device=down_weight.device)
new_down_weight[:, :intermediate_size] = down_weight
getattr(module, mlp_down_name).input_size_per_partition = padding_intermediate_size
getattr(module, mlp_down_name).input_size = padding_intermediate_size
getattr(module, mlp_down_name).weight = torch.nn.Parameter(new_down_weight, requires_grad=False)
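# Illustrative sketch (not part of this diff): the round-up arithmetic that
# padding_mlp applies to the intermediate size. The sample value is a
# placeholder chosen only to show a size that is not already a multiple of 256.
def _example_padding_size(intermediate_size: int = 13696,
                          padding_size: int = 256) -> int:
    padded = (intermediate_size + padding_size - 1) // padding_size * padding_size
    # 13696 -> 13824; sizes already divisible by 256 are returned unchanged,
    # matching the early return in padding_mlp.
    return padded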