vLLM: update vLLM XPU to 0.8.3 version (#13118)

Xiangyu Tian authored 2025-04-30 14:40:53 +08:00, committed by GitHub
parent f66eee1d1d
commit 51b41faad7
11 changed files with 4608 additions and 28217 deletions


@ -1,7 +1,9 @@
From d345631f78a2f33ff1ddd7d9908b288eb0afaf46 Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Fri, 24 May 2024 09:47:26 +0800
Subject: [PATCH 1/3] allreduce optimization with LL256 for Arc770 dGPU
From dfe1851b59df6859829b447353307b7c916ccee0 Mon Sep 17 00:00:00 2001
From: junhansh <junhan.shi@intel.com>
Date: Mon, 28 Apr 2025 23:33:11 +0800
Subject: [PATCH] oneccl for Arc770 V2025.0.0.6.7
allreduce optimization with LL256 for Arc770 dGPU
To enable this feature, please set env var:
export CCL_DG2_ALLREDUCE=1
@ -12,6 +14,15 @@ Build:
3. cmake .. -GNinja -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCMAKE_CXX_FLAGS="-fsycl" -DCOMPUTE_BACKEND=dpcpp -DCMAKE_BUILD_TYPE=MinSizeRel
4. ninja
5. ls -al src/libccl*
Changes:
optimize req_workgroup calculate
Revert "optimize req_workgroup calculate" for hang issue
This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
fix_fdset_buffer_overflow_issue
---
src/CMakeLists.txt | 2 +
src/coll/coll.cpp | 30 +-
@ -20,9 +31,9 @@ Build:
src/common/env/env.cpp | 1 +
src/common/env/env.hpp | 1 +
src/common/env/vars.hpp | 1 +
src/dg2/dg2_allreduce.cpp | 642 +++++++++++++++++++++++++++++++
src/dg2/dg2_allreduce.cpp | 640 +++++++++++++++++++++++++++++++
src/dg2/dg2_allreduce.hpp | 13 +
9 files changed, 693 insertions(+), 3 deletions(-)
9 files changed, 691 insertions(+), 3 deletions(-)
create mode 100644 src/dg2/dg2_allreduce.cpp
create mode 100644 src/dg2/dg2_allreduce.hpp
@ -163,10 +174,10 @@ index 73dcf77..84ab518 100644
constexpr const char* CCL_MIN_CHUNK_SIZE = "CCL_MIN_CHUNK_SIZE";
diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
new file mode 100644
index 0000000..15ace74
index 0000000..73e114b
--- /dev/null
+++ b/src/dg2/dg2_allreduce.cpp
@@ -0,0 +1,642 @@
@@ -0,0 +1,640 @@
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/un.h>
@ -178,7 +189,7 @@ index 0000000..15ace74
+#include <drm/drm.h>
+
+#include <mpi.h>
+
+#include <poll.h>
+#include <vector>
+#include <sstream>
+#include <iostream>
@ -315,7 +326,6 @@ index 0000000..15ace74
+
+static void *thread_func(void *arg)
+{
+ fd_set fds;
+ int count = 0;
+ char sock_path[64];
+ int peer_buf_fd = 0;
@ -323,6 +333,10 @@ index 0000000..15ace74
+
+ snprintf(sock_path, sizeof(sock_path), "%s-%d_%d", SOCK_PATH, rank, 0xa770);
+ int srv_fd = srv_sock(sock_path);
+ if (srv_fd < 0) {
+ perror("srv_sock failed");
+ return nullptr;
+ }
+
+ //std::cout << "-----> srv_fd of " << sock_path << " : " << srv_fd << "\n";
+
@ -331,35 +345,30 @@ index 0000000..15ace74
+ ze_context_handle_t ze_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
+ ze_device_handle_t ze_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
+
+ FD_ZERO(&fds);
+ FD_SET(srv_fd, &fds);
+ struct pollfd pfd = {
+ .fd = srv_fd,
+ .events = POLL_IN,
+ .revents = 0
+ };
+ while (++count < world_size) {
+ int ret = select(srv_fd + 1, &fds, NULL, NULL, NULL);
+ switch (ret) {
+ case 1:
+ {
+ int peer_rank;
+ void *peer_buf;
+ int ret = poll(&pfd, 1, -1);
+ if (ret <= 0) {
+ std::cerr << "poll failed: " << strerror(errno) << "\n";
+ break;
+ }
+
+ int conn_fd = accept(srv_fd, NULL, 0);
+ ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
+ if (pfd.revents & POLL_IN) {
+ int peer_rank;
+ void *peer_buf = nullptr;
+
+ ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
+ zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED /* cached allocation */, &peer_buf);
+ int conn_fd = accept(srv_fd, NULL, 0);
+ ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
+ ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
+ zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED, &peer_buf);
+
+ peer_bufs[peer_rank] = peer_buf;
+ //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
+
+ if (conn_fd > 0) close(conn_fd);
+
+ break;
+ }
+ case 0:
+ case -1:
+ std::cout << "srv_fd select() failed" << "\n";
+ break;
+ default:
+ break;
+ peer_bufs[peer_rank] = peer_buf;
+ //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
+ if (conn_fd > 0) close(conn_fd);
+ }
+ }
+
@ -831,105 +840,3 @@ index 0000000..0506445
--
2.34.1
From 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Fri, 7 Mar 2025 15:15:35 +0800
Subject: [PATCH 2/3] optimize req_workgroup calculate
---
src/dg2/dg2_allreduce.cpp | 25 ++-----------------------
1 file changed, 2 insertions(+), 23 deletions(-)
diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
index 15ace74..83270ae 100644
--- a/src/dg2/dg2_allreduce.cpp
+++ b/src/dg2/dg2_allreduce.cpp
@@ -527,30 +527,9 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
auto chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
auto chunk_with_pattern = sg_sz * LS_SZ; /* aligned to 256B */
- /* items will be assigned to each rank */
- auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
- auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
- auto req_subgroups = 0;
-
- if (req_workgroups >= g_sz/l_sz) {
- req_workgroups = g_sz/l_sz;
- } else {
- if (group_id == (req_workgroups - 1)) {
- req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
-
- /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
- /* Note: req_subgroups % (l_sz/sg_sz) might be 0 */
- if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
- if ((per_rank_items % (sg_sz - 1)) != 0) {
- /* FIXME: */
- req_workitems = per_rank_items % (sg_sz - 1);
- chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
- }
- }
- }
- }
+ auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
- if (group_id < req_workgroups) {
+ if (work_left > 0) {
// step 1: push data to next GPU
{
offset = base + local_world_rank * chunk_sz;
--
2.34.1
From 1c58cc9ede5ca38138a270f9e5ff59bca74f51d4 Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Wed, 12 Mar 2025 13:29:27 +0800
Subject: [PATCH 3/3] Revert "optimize req_workgroup calculate" for hang issue
This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
---
src/dg2/dg2_allreduce.cpp | 25 +++++++++++++++++++++++--
1 file changed, 23 insertions(+), 2 deletions(-)
diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
index 83270ae..15ace74 100644
--- a/src/dg2/dg2_allreduce.cpp
+++ b/src/dg2/dg2_allreduce.cpp
@@ -527,9 +527,30 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
auto chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
auto chunk_with_pattern = sg_sz * LS_SZ; /* aligned to 256B */
- auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
+ /* items will be assigned to each rank */
+ auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
+ auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
+ auto req_subgroups = 0;
+
+ if (req_workgroups >= g_sz/l_sz) {
+ req_workgroups = g_sz/l_sz;
+ } else {
+ if (group_id == (req_workgroups - 1)) {
+ req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
+
+ /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
+ /* Note: req_subgroups % (l_sz/sg_sz) might be 0 */
+ if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
+ if ((per_rank_items % (sg_sz - 1)) != 0) {
+ /* FIXME: */
+ req_workitems = per_rank_items % (sg_sz - 1);
+ chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
+ }
+ }
+ }
+ }
- if (work_left > 0) {
+ if (group_id < req_workgroups) {
// step 1: push data to next GPU
{
offset = base + local_world_rank * chunk_sz;
--
2.34.1
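
The patch above gates its LL256 allreduce path on the CCL_DG2_ALLREDUCE environment variable. Below is a minimal, hedged Python sketch of enabling it before the distributed backend comes up; the oneccl_bindings_for_pytorch import name and the "ccl" backend string follow the torch-ccl wheel built later in this commit, and the rendezvous variables are assumed to be provided by mpirun or torchrun.

import os

# Opt in to the Arc770 (DG2) LL256 allreduce before oneCCL is initialised.
os.environ.setdefault("CCL_DG2_ALLREDUCE", "1")

import torch
import torch.distributed as dist
import intel_extension_for_pytorch  # noqa: F401  (XPU backend)
import oneccl_bindings_for_pytorch  # noqa: F401  (registers the "ccl" backend)

# MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE are assumed to be set by the
# launcher (mpirun or torchrun).
dist.init_process_group(backend="ccl")
x = torch.ones(4096, device="xpu")
dist.all_reduce(x)  # routed through the LL256 kernel when CCL_DG2_ALLREDUCE=1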


@ -54,19 +54,20 @@ RUN set -eux && \
#
# Install Intel PyTorch extension for LLM inference
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
#
# Build torch-ccl
mkdir -p /build && \
cd /build && \
git clone https://github.com/intel/torch-ccl.git && \
cd torch-ccl && \
git checkout ccl_torch2.5.0+xpu && \
git checkout ccl_torch2.6.0+xpu && \
git submodule sync && \
git submodule update --init --recursive && \
# This patch will enable build torch-ccl with pytorch 2.6 environment
git apply /tmp/ccl_torch.patch && \
# git apply /tmp/ccl_torch.patch && \
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel && \
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl
# Build oneCCL
pip install ninja && \
cd /build/ && \
@ -85,7 +86,7 @@ RUN set -eux && \
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
# Copy the built torch-ccl package from the build stage
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
COPY --from=build /llm/ /llm/
COPY --from=build /build/oneCCL/build/src/libccl.so.1.0 /opt/intel/1ccl-wks/lib/
COPY --from=build /build/oneCCL/build/src/libccl.so.1 /opt/intel/1ccl-wks/lib/
@ -144,9 +145,10 @@ RUN set -eux && \
# Install vllm dependencies
pip install --upgrade fastapi && \
pip install --upgrade "uvicorn[standard]" && \
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
#
# Install torch-ccl
pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
pip install /opt/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl && \
#
apt-get update && \
apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
@ -168,21 +170,19 @@ RUN set -eux && \
mkdir -p /llm && \
cd /llm && \
rm -rf /tmp/neo && \
# Install intel_extension_for_pytorch
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \
pip uninstall -y oneccl oneccl-devel && \
pip install intel-opencl-rt==2025.0.2 intel-openmp==2025.0.2 && \
#
# Install vllm
git clone -b v0.6.6.post1 https://github.com/vllm-project/vllm /llm/vllm && \
git clone -b v0.8.3 https://github.com/vllm-project/vllm /llm/vllm && \
cd /llm/vllm && \
git apply /llm/vllm_for_multi_arc.patch && \
pip install setuptools-scm && \
pip install setuptools-scm==8.2.0 setuptools==78.1.0 && \
pip install --upgrade cmake && \
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
pip install -v -r requirements/xpu.txt && \
VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
pip uninstall -y oneccl oneccl-devel && \
rm -rf /llm/vllm_for_multi_arc.patch && \
pip install mpi4py fastapi uvicorn openai && \
pip install ray
pip install ray numba
WORKDIR /llm/
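
A quick sanity check for an image built from the Dockerfile above, assuming the package and attribute names of the public wheels it installs; the expected version strings mirror the pins in the build steps (torch-ccl 2.6.0+xpu, IPEX 2.6.10+xpu, vLLM 0.8.3).

import torch
import intel_extension_for_pytorch as ipex
import vllm
import oneccl_bindings_for_pytorch as torch_ccl

print("torch:", torch.__version__)          # expected 2.6.x+xpu
print("ipex:", ipex.__version__)            # expected 2.6.10+xpu
print("vllm:", vllm.__version__)            # expected 0.8.3
print("torch-ccl:", torch_ccl.__version__)  # expected 2.6.0+xpu
print("xpu available:", torch.xpu.is_available())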


@ -32,6 +32,9 @@ export TORCH_LLM_ALLREDUCE=0
export CCL_SAME_STREAM=1
export CCL_BLOCKING_WAIT=0
export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=$LOAD_IN_LOW_BIT
source /opt/intel/1ccl-wks/setvars.sh
python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \

File diff suppressed because one or more lines are too long


@ -782,6 +782,9 @@ export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0
export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=fp8
source /opt/intel/1ccl-wks/setvars.sh
python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
@ -793,7 +796,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit fp8 \
--load-in-low-bit $IPEX_LLM_LOWBIT \
--max-model-len 2048 \
--max-num-batched-tokens 4000 \
--api-key <your-api-key> \
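
A minimal client sketch against the server launched above; the base URL, port, model name and API key are placeholders that mirror the flags in the command (vLLM serves on port 8000 by default).

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="<your-api-key>")
response = client.chat.completions.create(
    model="YOUR_MODEL_NAME",  # the name the server was started with
    messages=[{"role": "user", "content": "What is IPEX-LLM?"}],
    max_tokens=64,
)
print(response.choices[0].message.content)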


@ -50,9 +50,14 @@ pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorc
pip install setuptools-scm
pip install --upgrade cmake
# cd to your workdir
git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
git clone -b 0.8.3 https://github.com/analytics-zoo/vllm.git
cd vllm
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
pip install setuptools-scm==8.2.0 setuptools==78.1.0
pip install --upgrade cmake
pip install -v -r requirements/xpu.txt
VLLM_TARGET_DEVICE=xpu python setup.py install
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
pip uninstall -y oneccl oneccl-devel
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
pip install ray
@ -93,6 +98,8 @@ For vLLM, you can start the service using the following command:
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
export VLLM_RPC_TIMEOUT=100000
export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=fp8
# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
@ -107,7 +114,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit sym_int4 \
--load-in-low-bit $IPEX_LLM_LOWBIT \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
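
Besides the OpenAI-compatible server started by the command above, the same low-bit path can be used offline through IPEXLLMClass added in this commit. A hedged usage sketch, with a placeholder model path and values taken from the serving flags:

from ipex_llm.vllm.xpu.engine import IPEXLLMClass
from vllm import SamplingParams

llm = IPEXLLMClass(
    model="YOUR_MODEL_PATH",
    device="xpu",
    dtype="float16",
    enforce_eager=True,
    load_in_low_bit="sym_int4",  # same values as --load-in-low-bit
    max_model_len=4096,
)
outputs = llm.generate(["What is IPEX-LLM?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)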


@ -150,12 +150,13 @@ def is_linear_module(module):
if _VLLM_VERSION is None:
_VLLM_VERSION = get_package_version('vllm')
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear, ReplicatedLinear
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear,
MergedColumnParallelLinear, ReplicatedLinear,
]
if 'xpu' in _VLLM_VERSION:
VLLM_LINEAR_LIST.append(ParallelLMHead)


@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine, IPEXLLMAsyncV1Engine, IPEXLLMLLMV1Engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
"IPEXLLMAsyncV1Engine",
"IPEXLLMLLMV1Engine",
"run_mp_engine",
]
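
Usage-wise, the expanded export list above means the V1 engines are imported from the same package as the V0 ones:

from ipex_llm.vllm.xpu.engine import (
    IPEXLLMAsyncLLMEngine,   # V0 async engine (subclasses AsyncLLMEngine)
    IPEXLLMAsyncV1Engine,    # V1 async engine (subclasses AsyncLLM)
    IPEXLLMLLMEngine,        # V0 offline engine
    IPEXLLMLLMV1Engine,      # V1 offline engine
    IPEXLLMClass,            # drop-in replacement for vllm.LLM
    run_mp_engine,
)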


@ -38,6 +38,8 @@ logger = init_logger(__name__)
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
_is_converted = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -53,13 +55,39 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "AsyncLLMEngine":
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_vllm_config(
vllm_config=vllm_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_requests=disable_log_requests,
disable_log_stats=disable_log_stats,
)
class IPEXLLMAsyncV1Engine(AsyncLLM):
_is_converted = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -74,13 +102,39 @@ class IPEXLLMAsyncV1Engine(AsyncLLM):
load_in_low_bit: str = "sym_int4",
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, # noqa
) -> "AsyncLLM":
_ipex_llm_convert(load_in_low_bit)
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context, stat_loggers=stat_loggers)
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
disable_log_requests: bool = False,
disable_log_stats: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "AsyncLLM":
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_vllm_config(
vllm_config=vllm_config,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_requests=disable_log_requests,
disable_log_stats=disable_log_stats,
)
class IPEXLLMClass(LLM):
def __init__(
self,
model: str,
@ -94,20 +148,20 @@ class IPEXLLMClass(LLM):
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
seed: Optional[int] = None,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = True,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]]=None,
disable_async_output_proc: bool = False,
hf_overrides: Optional[HfOverrides]=None,
mm_processor_kwargs: Optional[dict[str, Any]]=None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
override_pooler_config: Optional[PoolerConfig] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
compilation_config: Optional[Union[int, dict[str, Any]]]=None,
load_in_low_bit: str = "sym_int4",
**kwargs,
) -> None:
@ -120,6 +174,13 @@ class IPEXLLMClass(LLM):
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
if "worker_cls" in kwargs:
worker_cls = kwargs["worker_cls"]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if isinstance(worker_cls, type):
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
if compilation_config is not None:
if isinstance(compilation_config, (int, dict)):
compilation_config_instance = CompilationConfig.from_cli(
@ -159,11 +220,13 @@ class IPEXLLMClass(LLM):
# Logic to switch between engines is done at runtime instead of import
# to avoid import order issues
self.engine_class = self.get_engine_class()
# print("!!! ", load_in_low_bit)
self.llm_engine = self.engine_class.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS,
load_in_low_bit=load_in_low_bit)
self.request_counter = Counter()
self.default_sampling_params: Union[dict[str, Any], None] = None
@staticmethod
def get_engine_class() -> Type[LLMEngine]:
@ -173,6 +236,8 @@ class IPEXLLMClass(LLM):
class IPEXLLMLLMV1Engine(V1LLMEngine):
_is_converted = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -188,14 +253,37 @@ class IPEXLLMLLMV1Engine(V1LLMEngine):
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_engine_args(engine_args,
usage_context,
stat_loggers,
enable_multiprocessing)
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
disable_log_stats: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_stats=disable_log_stats
)
class IPEXLLMLLMEngine(LLMEngine):
_is_converted = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -209,33 +297,89 @@ class IPEXLLMLLMEngine(LLMEngine):
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
_ipex_llm_convert(load_in_low_bit)
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_engine_args(engine_args, usage_context, stat_loggers)
@classmethod
def from_vllm_config(
cls,
vllm_config: VllmConfig,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
disable_log_stats: bool = False,
load_in_low_bit: str = "sym_int4",
) -> "LLMEngine":
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
stat_loggers=stat_loggers,
disable_log_stats=disable_log_stats
)
class IPEXLLMMQLLMEngine(MQLLMEngine):
_is_converted = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
_ipex_llm_convert(load_in_low_bit)
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_engine_args(engine_args, usage_context, ipc_path)
@classmethod
def from_vllm_config(cls, vllm_config: VllmConfig,
usage_context: UsageContext,
disable_log_requests: bool, disable_log_stats: bool,
ipc_path: str, load_in_low_bit: str) -> "MQLLMEngine":
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, load_in_low_bit: str, engine_alive):
if not cls._is_converted:
_ipex_llm_convert(load_in_low_bit)
cls._is_converted = True
return super().from_vllm_config(
vllm_config=vllm_config,
ipc_path=ipc_path,
usage_context=usage_context,
disable_log_requests=disable_log_requests,
disable_log_stats=disable_log_stats,
)
def signal_handler(*_) -> None:
# Interrupt server on sigterm
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
ipc_path: str, disable_log_stats: bool,
disable_log_requests: bool, load_in_low_bit, engine_alive):
try:
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
engine = IPEXLLMMQLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_stats=disable_log_stats,
disable_log_requests=disable_log_requests,
load_in_low_bit=load_in_low_bit,
ipc_path=ipc_path)
signal.signal(signal.SIGTERM, signal_handler)
engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path,
load_in_low_bit=load_in_low_bit)
engine.start()
except BaseException as e:
logger.exception(e)
engine_alive.value = False
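
The recurring change in the engine classes above is a class-level guard so that the IPEX-LLM low-bit conversion runs exactly once per process, whichever entry point (from_engine_args or the new from_vllm_config) is hit first. Reduced to a sketch, with a hypothetical mixin name:

from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert  # assumed import path

class _ConvertOnceMixin:
    _is_converted = False  # class-level flag, one per engine class

    @classmethod
    def _ensure_converted(cls, load_in_low_bit: str) -> None:
        # _ipex_llm_convert patches vLLM's model-loading path in place,
        # so a second call would be redundant.
        if not cls._is_converted:
            _ipex_llm_convert(load_in_low_bit)
            cls._is_converted = True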


@ -1,5 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import atexit
import gc
import importlib
import inspect
import multiprocessing
@ -10,16 +13,18 @@ import socket
import tempfile
import uuid
from argparse import Namespace
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
from typing import Annotated, Optional, Union
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
@ -27,17 +32,17 @@ from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine # type: ignore
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from ipex_llm.vllm.xpu.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import (load_chat_template,
resolve_hf_chat_template,
resolve_mistral_chat_template)
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
# from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@ -46,33 +51,46 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
LoadLoRAAdapterRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
TranscriptionRequest,
TranscriptionResponse,
UnloadLoRAAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
with_cancellation)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
is_valid_ipv6_address, set_ulimit)
from vllm.version import __version__ as VLLM_VERSION
@ -83,7 +101,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
_running_tasks: set[asyncio.Task] = set()
@asynccontextmanager
@ -102,6 +120,11 @@ async def lifespan(app: FastAPI):
task.add_done_callback(_running_tasks.remove)
else:
task = None
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
gc.collect()
gc.freeze()
try:
yield
finally:
@ -139,24 +162,49 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed.
"""
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
# Create the EngineConfig (determines if we can use V1).
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
# V1 AsyncLLM.
if envs.VLLM_USE_V1:
if disable_frontend_multiprocessing:
logger.warning(
"V1 is enabled, but got --disable-frontend-multiprocessing. "
"To disable frontend multiprocessing, set VLLM_USE_V1=0.")
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine as AsyncLLM
async_llm: Optional[AsyncLLM] = None
try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
load_in_low_bit=load_in_low_bit)
yield async_llm
finally:
if async_llm:
async_llm.shutdown()
# V0 AsyncLLM.
elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
or disable_frontend_multiprocessing):
engine_client: Optional[EngineClient] = None
try:
# When starting this, we are actually starting with the V1Engine
# Here we are doing a classification, we will need to do this in IPEX-LLM
engine_client = AsyncLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.OPENAI_API_SERVER,
engine_client = AsyncLLMEngine.from_vllm_config(
vllm_config=vllm_config,
usage_context=usage_context,
disable_log_requests=engine_args.disable_log_requests,
disable_log_stats=engine_args.disable_log_stats,
load_in_low_bit=load_in_low_bit)
yield engine_client
finally:
if engine_client and hasattr(engine_client, "shutdown"):
engine_client.shutdown()
# Otherwise, use the multiprocessing AsyncLLMEngine.
# V0MQLLMEngine.
else:
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
# Make TemporaryDirectory for prometheus multiprocessing
@ -183,14 +231,18 @@ async def build_async_engine_client_from_engine_args(
# so we need to spawn a new process
context = multiprocessing.get_context("spawn")
# Ensure we can serialize transformer config before spawning
maybe_register_config_serialize_by_value()
# The Process can raise an exception during startup, which may
# not actually result in an exitcode being reported. As a result
# we use a shared variable to communicate the information.
engine_alive = multiprocessing.Value('b', True, lock=False)
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path, load_in_low_bit, engine_alive))
engine_process = context.Process(
target=run_mp_engine,
args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
engine_args.disable_log_stats,
engine_args.disable_log_requests, load_in_low_bit, engine_alive))
engine_process.start()
engine_pid = engine_process.pid
assert engine_pid is not None, "Engine process failed to start."
@ -205,8 +257,7 @@ async def build_async_engine_client_from_engine_args(
atexit.register(_cleanup_ipc_path)
# Build RPCClient, which conforms to EngineClient Protocol.
engine_config = engine_args.create_engine_config()
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
engine_pid)
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_client)
@ -244,6 +295,43 @@ async def build_async_engine_client_from_engine_args(
multiprocess.mark_process_dead(engine_process.pid)
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
media_type = content_type.split(";", maxsplit=1)[0]
if media_type != "application/json":
raise HTTPException(
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
detail="Unsupported Media Type: Only 'application/json' is allowed"
)
save_dict = {}
import os
flag = os.getenv("VLLM_LOG_OUTPUT", None)
async def stream_generator(generator, request, request_id):
async for chunk in generator:
if request_id not in save_dict:
save_dict[request_id] = ""
import json
try:
data = chunk.strip()
if data.startswith('data: '):
data = data[len('data: '):]
else:
yield chunk
json_data = json.loads(data)
if 'choices' in json_data and len(json_data['choices']) > 0:
choice = json_data['choices'][0]
if 'delta' in choice:
save_dict[request_id] += choice["delta"]["content"]
elif 'text' in choice:
save_dict[request_id] += choice["text"]
except json.JSONDecodeError:
print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
pass # Done
yield chunk
router = APIRouter()
@ -254,6 +342,7 @@ def mount_metrics(app: FastAPI):
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import (CollectorRegistry, make_asgi_app,
multiprocess)
from prometheus_fastapi_instrumentator import Instrumentator
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
if prometheus_multiproc_dir_path is not None:
@ -261,6 +350,16 @@ def mount_metrics(app: FastAPI):
prometheus_multiproc_dir_path)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
],
registry=registry,
).add().instrument(app).expose(app)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
@ -298,7 +397,11 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
def score(request: Request) -> Optional[OpenAIServingScores]:
def score(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
def rerank(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
@ -306,6 +409,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@ -317,7 +424,31 @@ async def health(raw_request: Request) -> Response:
return Response(status_code=200)
@router.post("/tokenize")
@router.get("/load")
async def get_server_load_metrics(request: Request):
# This endpoint returns the current server load metrics.
# It tracks requests utilizing the GPU from the following routes:
# - /v1/chat/completions
# - /v1/completions
# - /v1/audio/transcriptions
# - /v1/embeddings
# - /pooling
# - /score
# - /v1/score
# - /rerank
# - /v1/rerank
# - /v2/rerank
return JSONResponse(
content={'server_load': request.app.state.server_load_metrics})
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
"""Ping check. Endpoint required for SageMaker"""
return await health(raw_request)
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
@ -332,7 +463,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
assert_never(generator)
@router.post("/detokenize")
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request)
@ -361,35 +492,10 @@ async def show_version():
return JSONResponse(content=ver)
save_dict = {}
import os
flag = os.getenv("VLLM_LOG_OUTPUT", None)
async def stream_generator(generator, request, request_id):
async for chunk in generator:
if request_id not in save_dict:
save_dict[request_id] = ""
import json
try:
data = chunk.strip()
if data.startswith('data: '):
data = data[len('data: '):]
else:
yield chunk
json_data = json.loads(data)
if 'choices' in json_data and len(json_data['choices']) > 0:
choice = json_data['choices'][0]
if 'delta' in choice:
save_dict[request_id] += choice["delta"]["content"]
elif 'text' in choice:
save_dict[request_id] += choice["text"]
except json.JSONDecodeError:
print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
pass # Done
yield chunk
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions",
dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
@ -401,7 +507,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
request_id = "chatcmpl-" \
f"{handler._base_request_id(raw_request, request.request_id)}"
print(f"First received request_id: {request_id}, request: {request}")
generator = await handler.create_chat_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
@ -418,8 +523,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions")
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request)
if handler is None:
@ -438,14 +544,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if flag is not None:
print(f"Received request-id:{request_id}, request:{request}, Output:{generator.model_dump()}")
return JSONResponse(content=generator.model_dump())
if flag is not None:
return StreamingResponse(content=stream_generator(generator, request, request_id), media_type="text/event-stream")
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings")
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
@ -460,6 +567,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
@ -488,8 +597,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/pooling")
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request)
if handler is None:
@ -506,8 +616,9 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never(generator)
@router.post("/score")
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
@ -524,8 +635,9 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score")
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning(
"To indicate that Score API is not part of standard OpenAI API, we "
@ -534,6 +646,160 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/v1/audio/transcriptions")
@with_cancellation
@load_aware_call
async def create_transcriptions(request: Annotated[TranscriptionRequest,
Form()],
raw_request: Request):
handler = transcription(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Transcriptions API")
audio_data = await request.file.read()
generator = await handler.create_transcription(audio_data, request,
raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TranscriptionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Rerank (Score) API")
generator = await handler.do_rerank(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once(
"To indicate that the rerank API is not part of the standard OpenAI"
" API, we have located it at `/rerank`. Please update your client "
"accordingly. (Note: Conforms to JinaAI rerank API)")
return await do_rerank(request, raw_request)
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
TASK_HANDLERS: dict[str, dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
},
"embed": {
"messages": (EmbeddingChatRequest, create_embedding),
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"default": (RerankRequest, do_rerank)
},
"rerank": {
"default": (RerankRequest, do_rerank)
},
"reward": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
"classify": {
"messages": (PoolingChatRequest, create_pooling),
"default": (PoolingCompletionRequest, create_pooling),
},
}
if envs.VLLM_SERVER_DEV_MODE:
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(raw_request: Request):
"""
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
"""
device = None
device_str = raw_request.query_params.get("device")
if device_str is not None:
device = Device[device_str.upper()]
logger.info("Resetting prefix cache with specific %s...", str(device))
await engine_client(raw_request).reset_prefix_cache(device)
return Response(status_code=200)
@router.post("/sleep")
async def sleep(raw_request: Request):
# get POST params
level = raw_request.query_params.get("level", "1")
await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
tags = raw_request.query_params.getlist("tags")
if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
logger.info("check whether the engine is sleeping")
is_sleeping = await engine_client(raw_request).is_sleeping()
return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request):
"""
For SageMaker, routes requests to other handlers based on model `task`.
"""
body = await raw_request.json()
task = raw_request.app.state.task
if task not in TASK_HANDLERS:
raise HTTPException(
status_code=400,
detail=f"Unsupported task: '{task}' for '/invocations'. "
f"Expected one of {set(TASK_HANDLERS.keys())}")
handler_config = TASK_HANDLERS[task]
if "messages" in body:
request_model, handler = handler_config["messages"]
else:
request_model, handler = handler_config["default"]
# this is required since we lose the FastAPI automatic casting
request = request_model.model_validate(body)
return await handler(request, raw_request)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
@ -556,32 +822,30 @@ if envs.VLLM_TORCH_PROFILER_DIR:
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
@router.post("/v1/load_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoRAAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
handler = models(raw_request)
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
@router.post("/v1/unload_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
handler = models(raw_request)
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@ -615,7 +879,8 @@ def build_app(args: Namespace) -> FastAPI:
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
if token := envs.VLLM_API_KEY or args.api_key:
# Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
if token := args.api_key or envs.VLLM_API_KEY:
@app.middleware("http")
async def authentication(request: Request, call_next):
@ -644,11 +909,26 @@ def build_app(args: Namespace) -> FastAPI:
response.headers["X-Request-Id"] = request_id
return response
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
logger.warning("CAUTION: Enabling log response in the API Server. "
"This can include sensitive information and should be "
"avoided in production.")
@app.middleware("http")
async def log_response(request: Request, call_next):
response = await call_next(request)
response_body = [
section async for section in response.body_iterator
]
response.body_iterator = iterate_in_threadpool(iter(response_body))
logger.info("response_body={%s}", response_body[0].decode())
return response
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
@ -658,7 +938,7 @@ def build_app(args: Namespace) -> FastAPI:
return app
def init_app_state(
async def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
@ -683,15 +963,36 @@ def init_app_state(
state.log_stats = not args.disable_log_stats
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
if resolved_chat_template is not None:
# Get the tokenizer to check official template
tokenizer = await engine_client.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
# The warning is logged in resolve_mistral_chat_template.
resolved_chat_template = resolve_mistral_chat_template(
chat_template=resolved_chat_template)
else:
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=model_config.trust_remote_code)
if hf_chat_template != resolved_chat_template:
logger.warning(
"Using supplied chat template: %s\n"
"It is different from official chat template '%s'. "
"This discrepancy may lead to performance degradation.",
resolved_chat_template, args.model)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
# TODO: The chat template is now broken for lora adapters :(
await state.openai_serving_models.init_static_loras()
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@ -703,6 +1004,8 @@ def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
enable_reasoning=args.enable_reasoning,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
@ -728,7 +1031,13 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores(
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
@ -742,92 +1051,26 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.openai_serving_transcription = OpenAIServingTranscription(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
state.task = model_config.task
# if args.served_model_name is not None:
# served_model_names = args.served_model_name
# else:
# served_model_names = [args.model]
# if args.disable_log_requests:
# request_logger = None
# else:
# request_logger = RequestLogger(max_log_len=args.max_log_len)
# base_model_paths = [
# BaseModelPath(name=name, model_path=args.model)
# for name in served_model_names
# ]
# state.engine_client = engine_client
# state.log_stats = not args.disable_log_stats
# resolved_chat_template = load_chat_template(args.chat_template)
# logger.info("Using supplied chat template:\n%s", resolved_chat_template)
# state.openai_serving_chat = OpenAIServingChat(
# engine_client,
# model_config,
# base_model_paths,
# args.response_role,
# lora_modules=args.lora_modules,
# prompt_adapters=args.prompt_adapters,
# request_logger=request_logger,
# chat_template=resolved_chat_template,
# chat_template_content_format=args.chat_template_content_format,
# return_tokens_as_token_ids=args.return_tokens_as_token_ids,
# enable_auto_tools=args.enable_auto_tool_choice,
# tool_parser=args.tool_call_parser,
# enable_prompt_tokens_details=args.enable_prompt_tokens_details,
# ) if model_config.runner_type == "generate" else None
# state.openai_serving_completion = OpenAIServingCompletion(
# engine_client,
# model_config,
# base_model_paths,
# lora_modules=args.lora_modules,
# prompt_adapters=args.prompt_adapters,
# request_logger=request_logger,
# return_tokens_as_token_ids=args.return_tokens_as_token_ids,
# ) if model_config.runner_type == "generate" else None
# state.openai_serving_pooling = OpenAIServingPooling(
# engine_client,
# model_config,
# base_model_paths,
# request_logger=request_logger,
# chat_template=resolved_chat_template,
# chat_template_content_format=args.chat_template_content_format,
# ) if model_config.runner_type == "pooling" else None
# state.openai_serving_embedding = OpenAIServingEmbedding(
# engine_client,
# model_config,
# base_model_paths,
# request_logger=request_logger,
# chat_template=resolved_chat_template,
# chat_template_content_format=args.chat_template_content_format,
# ) if model_config.task == "embed" else None
# state.openai_serving_scores = OpenAIServingScores(
# engine_client,
# model_config,
# base_model_paths,
# request_logger=request_logger
# ) if model_config.task == "score" else None
# state.openai_serving_tokenization = OpenAIServingTokenization(
# engine_client,
# model_config,
# base_model_paths,
# lora_modules=args.lora_modules,
# request_logger=request_logger,
# chat_template=resolved_chat_template,
# chat_template_content_format=args.chat_template_content_format,
# )
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
family = socket.AF_INET
if is_valid_ipv6_address(addr[0]):
family = socket.AF_INET6
sock = socket.socket(family=family, type=socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.bind(addr)
return sock
@ -840,11 +1083,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
valide_tool_parses = ToolParserManager.tool_parsers.keys()
valid_tool_parses = ToolParserManager.tool_parsers.keys()
if args.enable_auto_tool_choice \
and args.tool_call_parser not in valide_tool_parses:
and args.tool_call_parser not in valid_tool_parses:
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
f"(chose from {{ {','.join(valide_tool_parses)} }})")
f"(chose from {{ {','.join(valid_tool_parses)} }})")
valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
if args.enable_reasoning \
and args.reasoning_parser not in valid_reasoning_parses:
raise KeyError(
f"invalid reasoning parser: {args.reasoning_parser} "
f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
@ -866,13 +1116,28 @@ async def run_server(args, **uvicorn_kwargs) -> None:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
await init_app_state(engine_client, model_config, app.state, args)
def _listen_addr(a: str) -> str:
if is_valid_ipv6_address(a):
return '[' + a + ']'
return a or "0.0.0.0"
is_ssl = args.ssl_keyfile and args.ssl_certfile
logger.info("Starting vLLM API server on http%s://%s:%d",
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])
shutdown_task = await serve_http(
app,
sock=sock,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
@ -882,16 +1147,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
)
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
sock.close()
try:
await shutdown_task
finally:
sock.close()
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints.
logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
"instead of `vllm.entrypoints.openai.api_server` to start the API server")
cli_env_setup()
parser = FlexibleArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser = make_arg_parser(parser)
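
One behavioural consequence of the validate_json_request dependency added above: the JSON endpoints now reject non-JSON content types up front with 415 Unsupported Media Type. A client-side sketch, assuming the requests package and a local server with no API key configured:

import requests

r = requests.post(
    "http://localhost:8000/v1/completions",
    data="prompt=hello",  # not JSON
    headers={"Content-Type": "application/x-www-form-urlencoded"},
)
print(r.status_code)  # expected: 415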


@ -48,7 +48,7 @@ def _sample_get_logits(
logits = lm_head(hidden_states)
if embedding_bias is not None:
logits += embedding_bias
if self.use_gather:
if self.use_all_gather:
logits = tensor_model_parallel_gather(logits)
else:
logits = tensor_model_parallel_all_gather(logits)
@ -63,6 +63,8 @@ def _model_sample_convert():
def _ipex_llm_convert(load_in_low_bit):
# import pdb
# pdb.set_trace()
from vllm.worker.xpu_model_runner import XPUModelRunner
from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
from ipex_llm.vllm.xpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper
@ -99,7 +101,8 @@ def get_load_function(low_bit):
"codegeex4-all" in self.vllm_config.model_config.model.lower() or
"chatglm" in self.vllm_config.model_config.model.lower()) and \
"gptq" not in self.model_config.model.lower() and \
"awq" not in self.model_config.model.lower():
"awq" not in self.model_config.model.lower() and \
"qwen3" not in self.model_config.model.lower():
self.model.apply(padding_mlp)
from ipex_llm import optimize_model
not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)