vLLM: update vLLM XPU to 0.8.3 version (#13118)
vLLM: update vLLM XPU to 0.8.3 version
This commit is contained in:
parent
f66eee1d1d
commit
51b41faad7
11 changed files with 4608 additions and 28217 deletions
@@ -1,7 +1,9 @@
From d345631f78a2f33ff1ddd7d9908b288eb0afaf46 Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Fri, 24 May 2024 09:47:26 +0800
Subject: [PATCH 1/3] allreduce optimization with LL256 for Arc770 dGPU
From dfe1851b59df6859829b447353307b7c916ccee0 Mon Sep 17 00:00:00 2001
From: junhansh <junhan.shi@intel.com>
Date: Mon, 28 Apr 2025 23:33:11 +0800
Subject: [PATCH] oneccl for Arc770 V2025.0.0.6.7

allreduce optimization with LL256 for Arc770 dGPU

To enable this feature, please set env var:
export CCL_DG2_ALLREDUCE=1

@@ -12,6 +14,15 @@ Build:
3. cmake .. -GNinja -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCMAKE_CXX_FLAGS="-fsycl" -DCOMPUTE_BACKEND=dpcpp -DCMAKE_BUILD_TYPE=MinSizeRel
4. ninja
5. ls -al src/libccl*

Changes:
optimize req_workgroup calculate

Revert "optimize req_workgroup calculate" for hang issue

This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.

fix_fdset_buffer_overflow_issue
---
src/CMakeLists.txt | 2 +
src/coll/coll.cpp | 30 +-
@@ -20,9 +31,9 @@ Build:
src/common/env/env.cpp | 1 +
src/common/env/env.hpp | 1 +
src/common/env/vars.hpp | 1 +
src/dg2/dg2_allreduce.cpp | 642 +++++++++++++++++++++++++++++++
src/dg2/dg2_allreduce.cpp | 640 +++++++++++++++++++++++++++++++
src/dg2/dg2_allreduce.hpp | 13 +
9 files changed, 693 insertions(+), 3 deletions(-)
9 files changed, 691 insertions(+), 3 deletions(-)
create mode 100644 src/dg2/dg2_allreduce.cpp
create mode 100644 src/dg2/dg2_allreduce.hpp

@@ -163,10 +174,10 @@ index 73dcf77..84ab518 100644
constexpr const char* CCL_MIN_CHUNK_SIZE = "CCL_MIN_CHUNK_SIZE";
diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
new file mode 100644
index 0000000..15ace74
index 0000000..73e114b
--- /dev/null
+++ b/src/dg2/dg2_allreduce.cpp
@@ -0,0 +1,642 @@
@@ -0,0 +1,640 @@
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/un.h>
@@ -178,7 +189,7 @@ index 0000000..15ace74
+#include <drm/drm.h>
+
+#include <mpi.h>
+
+#include <poll.h>
+#include <vector>
+#include <sstream>
+#include <iostream>
@@ -315,7 +326,6 @@ index 0000000..15ace74
+
+static void *thread_func(void *arg)
+{
+ fd_set fds;
+ int count = 0;
+ char sock_path[64];
+ int peer_buf_fd = 0;
@@ -323,6 +333,10 @@ index 0000000..15ace74
+
+ snprintf(sock_path, sizeof(sock_path), "%s-%d_%d", SOCK_PATH, rank, 0xa770);
+ int srv_fd = srv_sock(sock_path);
+ if (srv_fd < 0) {
+ perror("srv_sock failed");
+ return nullptr;
+ }
+
+ //std::cout << "-----> srv_fd of " << sock_path << " : " << srv_fd << "\n";
+
@@ -331,35 +345,30 @@ index 0000000..15ace74
+ ze_context_handle_t ze_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
+ ze_device_handle_t ze_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
+
+ FD_ZERO(&fds);
+ FD_SET(srv_fd, &fds);
+ struct pollfd pfd = {
+ .fd = srv_fd,
+ .events = POLL_IN,
+ .revents = 0
+ };
+ while (++count < world_size) {
+ int ret = select(srv_fd + 1, &fds, NULL, NULL, NULL);
+ switch (ret) {
+ case 1:
+ {
+ int peer_rank;
+ void *peer_buf;
+ int ret = poll(&pfd, 1, -1);
+ if (ret <= 0) {
+ std::cerr << "poll failed: " << strerror(errno) << "\n";
+ break;
+ }
+
+ int conn_fd = accept(srv_fd, NULL, 0);
+ ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
+ if (pfd.revents & POLL_IN) {
+ int peer_rank;
+ void *peer_buf = nullptr;
+
+ ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
+ zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED /* cached allocation */, &peer_buf);
+ int conn_fd = accept(srv_fd, NULL, 0);
+ ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
+ ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
+ zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED, &peer_buf);
+
+ peer_bufs[peer_rank] = peer_buf;
+ //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
+
+ if (conn_fd > 0) close(conn_fd);
+
+ break;
+ }
+ case 0:
+ case -1:
+ std::cout << "srv_fd select() failed" << "\n";
+ break;
+ default:
+ break;
+ peer_bufs[peer_rank] = peer_buf;
+ //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
+ if (conn_fd > 0) close(conn_fd);
+ }
+ }
+
@@ -831,105 +840,3 @@ index 0000000..0506445
--
2.34.1


From 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Fri, 7 Mar 2025 15:15:35 +0800
Subject: [PATCH 2/3] optimize req_workgroup calculate

---
src/dg2/dg2_allreduce.cpp | 25 ++-----------------------
1 file changed, 2 insertions(+), 23 deletions(-)

diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
index 15ace74..83270ae 100644
--- a/src/dg2/dg2_allreduce.cpp
+++ b/src/dg2/dg2_allreduce.cpp
@@ -527,30 +527,9 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
auto chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
auto chunk_with_pattern = sg_sz * LS_SZ; /* aligned to 256B */

- /* items will be assigned to each rank */
- auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
- auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
- auto req_subgroups = 0;
-
- if (req_workgroups >= g_sz/l_sz) {
- req_workgroups = g_sz/l_sz;
- } else {
- if (group_id == (req_workgroups - 1)) {
- req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
-
- /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
- /* Note: req_subgroups % (l_sz/sg_sz) might be 0 */
- if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
- if ((per_rank_items % (sg_sz - 1)) != 0) {
- /* FIXME: */
- req_workitems = per_rank_items % (sg_sz - 1);
- chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
- }
- }
- }
- }
+ auto work_left = unreduced - sg_id * local_world_size * chunk_sz;

- if (group_id < req_workgroups) {
+ if (work_left > 0) {
// step 1: push data to next GPU
{
offset = base + local_world_rank * chunk_sz;
--
2.34.1

From 1c58cc9ede5ca38138a270f9e5ff59bca74f51d4 Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Wed, 12 Mar 2025 13:29:27 +0800
Subject: [PATCH 3/3] Revert "optimize req_workgroup calculate" for hang issue

This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
---
src/dg2/dg2_allreduce.cpp | 25 +++++++++++++++++++++++--
1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
index 83270ae..15ace74 100644
--- a/src/dg2/dg2_allreduce.cpp
+++ b/src/dg2/dg2_allreduce.cpp
@@ -527,9 +527,30 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
auto chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
auto chunk_with_pattern = sg_sz * LS_SZ; /* aligned to 256B */

- auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
+ /* items will be assigned to each rank */
+ auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
+ auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
+ auto req_subgroups = 0;
+
+ if (req_workgroups >= g_sz/l_sz) {
+ req_workgroups = g_sz/l_sz;
+ } else {
+ if (group_id == (req_workgroups - 1)) {
+ req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
+
+ /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
+ /* Note: req_subgroups % (l_sz/sg_sz) might be 0 */
+ if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
+ if ((per_rank_items % (sg_sz - 1)) != 0) {
+ /* FIXME: */
+ req_workitems = per_rank_items % (sg_sz - 1);
+ chunk_sz = req_workitems * LS_SZ; /* LS_SZ bytes per work-item */
+ }
+ }
+ }
+ }

- if (work_left > 0) {
+ if (group_id < req_workgroups) {
// step 1: push data to next GPU
{
offset = base + local_world_rank * chunk_sz;
--
2.34.1

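For quick reference, enabling the LL256 allreduce path above at runtime only takes environment variables; a minimal sketch (assuming the patched libccl built from this file is the one picked up at runtime, as in the image below):

source /opt/intel/1ccl-wks/setvars.sh   # use the oneCCL built with this patch
export CCL_DG2_ALLREDUCE=1              # opt into the LL256 allreduce for Arc770 dGPUs
export CCL_SAME_STREAM=1
export CCL_BLOCKING_WAIT=0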
@@ -54,19 +54,20 @@ RUN set -eux && \
#
# Install Intel PyTorch extension for LLM inference
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
#
# Build torch-ccl
mkdir -p /build && \
cd /build && \
git clone https://github.com/intel/torch-ccl.git && \
cd torch-ccl && \
git checkout ccl_torch2.5.0+xpu && \
git checkout ccl_torch2.6.0+xpu && \
git submodule sync && \
git submodule update --init --recursive && \
# This patch will enable build torch-ccl with pytorch 2.6 environment
git apply /tmp/ccl_torch.patch && \
# git apply /tmp/ccl_torch.patch && \
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel && \
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl
# Build oneCCL
pip install ninja && \
cd /build/ && \

@@ -85,7 +86,7 @@ RUN set -eux && \
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04

# Copy the built torch-ccl package from the build stage
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
COPY --from=build /llm/ /llm/
COPY --from=build /build/oneCCL/build/src/libccl.so.1.0 /opt/intel/1ccl-wks/lib/
COPY --from=build /build/oneCCL/build/src/libccl.so.1 /opt/intel/1ccl-wks/lib/

@@ -144,9 +145,10 @@ RUN set -eux && \
# Install vllm dependencies
pip install --upgrade fastapi && \
pip install --upgrade "uvicorn[standard]" && \
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
#
# Install torch-ccl
pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
pip install /opt/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl && \
#
apt-get update && \
apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \

@@ -168,21 +170,19 @@ RUN set -eux && \
mkdir -p /llm && \
cd /llm && \
rm -rf /tmp/neo && \
# Install intel_extension_for_pytorch
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \
pip uninstall -y oneccl oneccl-devel && \
pip install intel-opencl-rt==2025.0.2 intel-openmp==2025.0.2 && \
#
# Install vllm
git clone -b v0.6.6.post1 https://github.com/vllm-project/vllm /llm/vllm && \
git clone -b v0.8.3 https://github.com/vllm-project/vllm /llm/vllm && \
cd /llm/vllm && \
git apply /llm/vllm_for_multi_arc.patch && \
pip install setuptools-scm && \
pip install setuptools-scm==8.2.0 setuptools==78.1.0 && \
pip install --upgrade cmake && \
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
pip install -v -r requirements/xpu.txt && \
VLLM_TARGET_DEVICE=xpu python setup.py install && \
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
pip uninstall -y oneccl oneccl-devel && \
rm -rf /llm/vllm_for_multi_arc.patch && \
pip install mpi4py fastapi uvicorn openai && \
pip install ray
pip install ray numba

WORKDIR /llm/
@@ -32,6 +32,9 @@ export TORCH_LLM_ALLREDUCE=0
export CCL_SAME_STREAM=1
export CCL_BLOCKING_WAIT=0

export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=$LOAD_IN_LOW_BIT

source /opt/intel/1ccl-wks/setvars.sh

python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
File diff suppressed because one or more lines are too long
@@ -782,6 +782,9 @@ export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0

export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=fp8

source /opt/intel/1ccl-wks/setvars.sh

python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
@@ -793,7 +796,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit fp8 \
--load-in-low-bit $IPEX_LLM_LOWBIT \
--max-model-len 2048 \
--max-num-batched-tokens 4000 \
--api-key <your-api-key> \
@@ -50,9 +50,14 @@ pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorc
pip install setuptools-scm
pip install --upgrade cmake
# cd to your workdir
git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
git clone -b 0.8.3 https://github.com/analytics-zoo/vllm.git
cd vllm
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
pip install setuptools-scm==8.2.0 setuptools==78.1.0
pip install --upgrade cmake
pip install -v -r requirements/xpu.txt
VLLM_TARGET_DEVICE=xpu python setup.py install
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
pip uninstall -y oneccl oneccl-devel
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
pip install ray

@@ -93,6 +98,8 @@ For vLLM, you can start the service using the following command:
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
export VLLM_RPC_TIMEOUT=100000
export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=fp8

# You may need to adjust the value of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs

@@ -107,7 +114,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit sym_int4 \
--load-in-low-bit $IPEX_LLM_LOWBIT \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
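Once the service is up, a minimal request against the OpenAI-compatible endpoint works as a smoke test; this sketch assumes the default port 8000 and no --api-key, so adjust host, port, model name, and headers to your deployment:

curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "YOUR_MODEL_NAME", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 32}'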
@@ -150,12 +150,13 @@ def is_linear_module(module):
if _VLLM_VERSION is None:
_VLLM_VERSION = get_package_version('vllm')
from vllm.model_executor.layers.linear import (
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear, ReplicatedLinear
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
VLLM_LINEAR_LIST = [
ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
MergedColumnParallelLinear,
MergedColumnParallelLinear, ReplicatedLinear,
]
if 'xpu' in _VLLM_VERSION:
VLLM_LINEAR_LIST.append(ParallelLMHead)
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine, IPEXLLMAsyncV1Engine, IPEXLLMLLMV1Engine
__all__ = [
"IPEXLLMAsyncLLMEngine",
"IPEXLLMLLMEngine",
"IPEXLLMClass",
"IPEXLLMAsyncV1Engine",
"IPEXLLMLLMV1Engine",
"run_mp_engine",
]
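A quick way to confirm the newly exported V1 engine wrappers resolve after installation is an import smoke test (a minimal sketch; it only checks that the package imports, not that the XPU runtime is functional):

python -c "from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine, IPEXLLMLLMV1Engine; print('engine imports OK')"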
@ -38,6 +38,8 @@ logger = init_logger(__name__)
|
|||
|
||||
|
||||
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
|
||||
_is_converted = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
|
@ -53,13 +55,39 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
|
|||
) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context, stat_loggers=stat_loggers)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
|
||||
disable_log_requests: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> "AsyncLLMEngine":
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
disable_log_requests=disable_log_requests,
|
||||
disable_log_stats=disable_log_stats,
|
||||
)
|
||||
|
||||
|
||||
class IPEXLLMAsyncV1Engine(AsyncLLM):
|
||||
_is_converted = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -74,13 +102,39 @@ class IPEXLLMAsyncV1Engine(AsyncLLM):
|
|||
load_in_low_bit: str = "sym_int4",
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, # noqa
|
||||
) -> "AsyncLLM":
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context, stat_loggers=stat_loggers)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
|
||||
disable_log_requests: bool = False,
|
||||
disable_log_stats: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> "AsyncLLM":
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
disable_log_requests=disable_log_requests,
|
||||
disable_log_stats=disable_log_stats,
|
||||
)
|
||||
|
||||
|
||||
class IPEXLLMClass(LLM):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -94,20 +148,20 @@ class IPEXLLMClass(LLM):
|
|||
quantization: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
tokenizer_revision: Optional[str] = None,
|
||||
seed: int = 0,
|
||||
seed: Optional[int] = None,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: float = 4,
|
||||
cpu_offload_gb: float = 0,
|
||||
enforce_eager: Optional[bool] = None,
|
||||
max_seq_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
disable_async_output_proc: bool = True,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]]=None,
|
||||
disable_async_output_proc: bool = False,
|
||||
hf_overrides: Optional[HfOverrides]=None,
|
||||
mm_processor_kwargs: Optional[dict[str, Any]]=None,
|
||||
# After positional args are removed, move this right below `model`
|
||||
task: TaskOption = "auto",
|
||||
override_pooler_config: Optional[PoolerConfig] = None,
|
||||
compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
|
||||
compilation_config: Optional[Union[int, dict[str, Any]]]=None,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
|
|
@ -120,6 +174,13 @@ class IPEXLLMClass(LLM):
|
|||
if "disable_log_stats" not in kwargs:
|
||||
kwargs["disable_log_stats"] = True
|
||||
|
||||
if "worker_cls" in kwargs:
|
||||
worker_cls = kwargs["worker_cls"]
|
||||
# if the worker_cls is not qualified string name,
|
||||
# we serialize it using cloudpickle to avoid pickling issues
|
||||
if isinstance(worker_cls, type):
|
||||
kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
|
||||
|
||||
if compilation_config is not None:
|
||||
if isinstance(compilation_config, (int, dict)):
|
||||
compilation_config_instance = CompilationConfig.from_cli(
|
||||
|
|
@ -159,11 +220,13 @@ class IPEXLLMClass(LLM):
|
|||
# Logic to switch between engines is done at runtime instead of import
|
||||
# to avoid import order issues
|
||||
self.engine_class = self.get_engine_class()
|
||||
# print("!!! ", load_in_low_bit)
|
||||
self.llm_engine = self.engine_class.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.LLM_CLASS,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
|
||||
self.request_counter = Counter()
|
||||
self.default_sampling_params: Union[dict[str, Any], None] = None
|
||||
|
||||
@staticmethod
|
||||
def get_engine_class() -> Type[LLMEngine]:
|
||||
|
|
@ -173,6 +236,8 @@ class IPEXLLMClass(LLM):
|
|||
|
||||
|
||||
class IPEXLLMLLMV1Engine(V1LLMEngine):
|
||||
_is_converted = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
|
@ -188,14 +253,37 @@ class IPEXLLMLLMV1Engine(V1LLMEngine):
|
|||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_engine_args(engine_args,
|
||||
usage_context,
|
||||
stat_loggers,
|
||||
enable_multiprocessing)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
|
||||
disable_log_stats: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> "LLMEngine":
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
disable_log_stats=disable_log_stats
|
||||
)
|
||||
|
||||
|
||||
class IPEXLLMLLMEngine(LLMEngine):
|
||||
_is_converted = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
|
@ -209,33 +297,89 @@ class IPEXLLMLLMEngine(LLMEngine):
|
|||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_engine_args(engine_args, usage_context, stat_loggers)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
|
||||
disable_log_stats: bool = False,
|
||||
load_in_low_bit: str = "sym_int4",
|
||||
) -> "LLMEngine":
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
disable_log_stats=disable_log_stats
|
||||
)
|
||||
|
||||
|
||||
class IPEXLLMMQLLMEngine(MQLLMEngine):
|
||||
_is_converted = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls, engine_args: AsyncEngineArgs,
|
||||
usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_engine_args(engine_args, usage_context, ipc_path)
|
||||
|
||||
@classmethod
|
||||
def from_vllm_config(cls, vllm_config: VllmConfig,
|
||||
usage_context: UsageContext,
|
||||
disable_log_requests: bool, disable_log_stats: bool,
|
||||
ipc_path: str, load_in_low_bit: str) -> "MQLLMEngine":
|
||||
|
||||
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
|
||||
ipc_path: str, load_in_low_bit: str, engine_alive):
|
||||
if not cls._is_converted:
|
||||
_ipex_llm_convert(load_in_low_bit)
|
||||
cls._is_converted = True
|
||||
return super().from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
ipc_path=ipc_path,
|
||||
usage_context=usage_context,
|
||||
disable_log_requests=disable_log_requests,
|
||||
disable_log_stats=disable_log_stats,
|
||||
)
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm
|
||||
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
raise KeyboardInterrupt("MQLLMEngine terminated") # noqa
|
||||
|
||||
|
||||
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
|
||||
ipc_path: str, disable_log_stats: bool,
|
||||
disable_log_requests: bool, load_in_low_bit, engine_alive):
|
||||
try:
|
||||
# Ensure we can serialize transformer config before spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
engine = IPEXLLMMQLLMEngine.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
disable_log_stats=disable_log_stats,
|
||||
disable_log_requests=disable_log_requests,
|
||||
load_in_low_bit=load_in_low_bit,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
|
||||
usage_context=usage_context,
|
||||
ipc_path=ipc_path,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
engine.start()
|
||||
|
||||
except BaseException as e:
|
||||
logger.exception(e)
|
||||
engine_alive.value = False
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import gc
|
||||
import importlib
|
||||
import inspect
|
||||
import multiprocessing
|
||||
|
|
@ -10,16 +13,18 @@ import socket
|
|||
import tempfile
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Optional, Set, Tuple
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import State
|
||||
from starlette.routing import Mount
|
||||
from typing_extensions import assert_never
|
||||
|
|
@ -27,17 +32,17 @@ from typing_extensions import assert_never
|
|||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
|
||||
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine # type: ignore
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
from ipex_llm.vllm.xpu.engine import run_mp_engine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.chat_utils import (load_chat_template,
|
||||
resolve_hf_chat_template,
|
||||
resolve_mistral_chat_template)
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||
validate_parsed_serve_args)
|
||||
|
||||
# from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
|
|
@ -46,33 +51,46 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
LoadLoRAAdapterRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
UnloadLoraAdapterRequest)
|
||||
TranscriptionRequest,
|
||||
TranscriptionResponse,
|
||||
UnloadLoRAAdapterRequest)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.serving_transcription import (
|
||||
OpenAIServingTranscription)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.entrypoints.utils import with_cancellation
|
||||
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
||||
with_cancellation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||
is_valid_ipv6_address, set_ulimit)
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
|
@ -83,7 +101,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
|||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||
|
||||
_running_tasks: Set[asyncio.Task] = set()
|
||||
_running_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -102,6 +120,11 @@ async def lifespan(app: FastAPI):
|
|||
task.add_done_callback(_running_tasks.remove)
|
||||
else:
|
||||
task = None
|
||||
|
||||
# Mark the startup heap as static so that it's ignored by GC.
|
||||
# Reduces pause times of oldest generation collections.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
|
|
@ -139,24 +162,49 @@ async def build_async_engine_client_from_engine_args(
|
|||
Returns the Client or None if the creation failed.
|
||||
"""
|
||||
|
||||
# Fall back
|
||||
# TODO: fill out feature matrix.
|
||||
if (MQLLMEngineClient.is_unsupported_config(engine_args)
|
||||
or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
|
||||
# Create the EngineConfig (determines if we can use V1).
|
||||
usage_context = UsageContext.OPENAI_API_SERVER
|
||||
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
|
||||
|
||||
# V1 AsyncLLM.
|
||||
if envs.VLLM_USE_V1:
|
||||
if disable_frontend_multiprocessing:
|
||||
logger.warning(
|
||||
"V1 is enabled, but got --disable-frontend-multiprocessing. "
|
||||
"To disable frontend multiprocessing, set VLLM_USE_V1=0.")
|
||||
|
||||
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine as AsyncLLM
|
||||
async_llm: Optional[AsyncLLM] = None
|
||||
try:
|
||||
async_llm = AsyncLLM.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
disable_log_requests=engine_args.disable_log_requests,
|
||||
disable_log_stats=engine_args.disable_log_stats,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
yield async_llm
|
||||
finally:
|
||||
if async_llm:
|
||||
async_llm.shutdown()
|
||||
|
||||
# V0 AsyncLLM.
|
||||
elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
|
||||
or disable_frontend_multiprocessing):
|
||||
|
||||
engine_client: Optional[EngineClient] = None
|
||||
try:
|
||||
# When starting this, we are actually starting with the V1Engine
|
||||
# Here we are doing a classification, we will need to do this in IPEX-LLM
|
||||
engine_client = AsyncLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.OPENAI_API_SERVER,
|
||||
engine_client = AsyncLLMEngine.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
usage_context=usage_context,
|
||||
disable_log_requests=engine_args.disable_log_requests,
|
||||
disable_log_stats=engine_args.disable_log_stats,
|
||||
load_in_low_bit=load_in_low_bit)
|
||||
yield engine_client
|
||||
finally:
|
||||
if engine_client and hasattr(engine_client, "shutdown"):
|
||||
engine_client.shutdown()
|
||||
|
||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||
# V0MQLLMEngine.
|
||||
else:
|
||||
if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
|
||||
# Make TemporaryDirectory for prometheus multiprocessing
|
||||
|
|
@ -183,14 +231,18 @@ async def build_async_engine_client_from_engine_args(
|
|||
# so we need to spawn a new process
|
||||
context = multiprocessing.get_context("spawn")
|
||||
|
||||
# Ensure we can serialize transformer config before spawning
|
||||
maybe_register_config_serialize_by_value()
|
||||
|
||||
# The Process can raise an exception during startup, which may
|
||||
# not actually result in an exitcode being reported. As a result
|
||||
# we use a shared variable to communicate the information.
|
||||
engine_alive = multiprocessing.Value('b', True, lock=False)
|
||||
engine_process = context.Process(target=run_mp_engine,
|
||||
args=(engine_args,
|
||||
UsageContext.OPENAI_API_SERVER,
|
||||
ipc_path, load_in_low_bit, engine_alive))
|
||||
engine_process = context.Process(
|
||||
target=run_mp_engine,
|
||||
args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
|
||||
engine_args.disable_log_stats,
|
||||
engine_args.disable_log_requests, load_in_low_bit, engine_alive))
|
||||
engine_process.start()
|
||||
engine_pid = engine_process.pid
|
||||
assert engine_pid is not None, "Engine process failed to start."
|
||||
|
|
@ -205,8 +257,7 @@ async def build_async_engine_client_from_engine_args(
|
|||
atexit.register(_cleanup_ipc_path)
|
||||
|
||||
# Build RPCClient, which conforms to EngineClient Protocol.
|
||||
engine_config = engine_args.create_engine_config()
|
||||
build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
|
||||
build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
|
||||
engine_pid)
|
||||
mq_engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||
None, build_client)
|
||||
|
|
@ -244,6 +295,43 @@ async def build_async_engine_client_from_engine_args(
|
|||
multiprocess.mark_process_dead(engine_process.pid)
|
||||
|
||||
|
||||
async def validate_json_request(raw_request: Request):
|
||||
content_type = raw_request.headers.get("content-type", "").lower()
|
||||
media_type = content_type.split(";", maxsplit=1)[0]
|
||||
if media_type != "application/json":
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
|
||||
detail="Unsupported Media Type: Only 'application/json' is allowed"
|
||||
)
|
||||
|
||||
|
||||
save_dict = {}
|
||||
import os
|
||||
flag = os.getenv("VLLM_LOG_OUTPUT", None)
|
||||
async def stream_generator(generator, request, request_id):
|
||||
async for chunk in generator:
|
||||
if request_id not in save_dict:
|
||||
save_dict[request_id] = ""
|
||||
import json
|
||||
try:
|
||||
data = chunk.strip()
|
||||
if data.startswith('data: '):
|
||||
data = data[len('data: '):]
|
||||
else:
|
||||
yield chunk
|
||||
json_data = json.loads(data)
|
||||
if 'choices' in json_data and len(json_data['choices']) > 0:
|
||||
choice = json_data['choices'][0]
|
||||
if 'delta' in choice:
|
||||
save_dict[request_id] += choice["delta"]["content"]
|
||||
elif 'text' in choice:
|
||||
save_dict[request_id] += choice["text"]
|
||||
except json.JSONDecodeError:
|
||||
print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
|
||||
pass # Done
|
||||
yield chunk
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
|
@ -254,6 +342,7 @@ def mount_metrics(app: FastAPI):
|
|||
# See https://prometheus.github.io/client_python/multiprocess/
|
||||
from prometheus_client import (CollectorRegistry, make_asgi_app,
|
||||
multiprocess)
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
|
||||
prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
|
||||
if prometheus_multiproc_dir_path is not None:
|
||||
|
|
@ -261,6 +350,16 @@ def mount_metrics(app: FastAPI):
|
|||
prometheus_multiproc_dir_path)
|
||||
registry = CollectorRegistry()
|
||||
multiprocess.MultiProcessCollector(registry)
|
||||
Instrumentator(
|
||||
excluded_handlers=[
|
||||
"/metrics",
|
||||
"/health",
|
||||
"/load",
|
||||
"/ping",
|
||||
"/version",
|
||||
],
|
||||
registry=registry,
|
||||
).add().instrument(app).expose(app)
|
||||
|
||||
# Add prometheus asgi middleware to route /metrics requests
|
||||
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
|
||||
|
|
@ -298,7 +397,11 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
|
|||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def score(request: Request) -> Optional[OpenAIServingScores]:
|
||||
def score(request: Request) -> Optional[ServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def rerank(request: Request) -> Optional[ServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
|
|
@ -306,6 +409,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
|
|||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def transcription(request: Request) -> OpenAIServingTranscription:
|
||||
return request.app.state.openai_serving_transcription
|
||||
|
||||
|
||||
def engine_client(request: Request) -> EngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
|
@ -317,7 +424,31 @@ async def health(raw_request: Request) -> Response:
|
|||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/tokenize")
|
||||
@router.get("/load")
|
||||
async def get_server_load_metrics(request: Request):
|
||||
# This endpoint returns the current server load metrics.
|
||||
# It tracks requests utilizing the GPU from the following routes:
|
||||
# - /v1/chat/completions
|
||||
# - /v1/completions
|
||||
# - /v1/audio/transcriptions
|
||||
# - /v1/embeddings
|
||||
# - /pooling
|
||||
# - /score
|
||||
# - /v1/score
|
||||
# - /rerank
|
||||
# - /v1/rerank
|
||||
# - /v2/rerank
|
||||
return JSONResponse(
|
||||
content={'server_load': request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.api_route("/ping", methods=["GET", "POST"])
|
||||
async def ping(raw_request: Request) -> Response:
|
||||
"""Ping check. Endpoint required for SageMaker"""
|
||||
return await health(raw_request)
|
||||
|
||||
|
||||
@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
|
@ -332,7 +463,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
|
|||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/detokenize")
|
||||
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
handler = tokenization(raw_request)
|
||||
|
|
@ -361,35 +492,10 @@ async def show_version():
|
|||
return JSONResponse(content=ver)
|
||||
|
||||
|
||||
save_dict = {}
|
||||
import os
|
||||
flag = os.getenv("VLLM_LOG_OUTPUT", None)
|
||||
async def stream_generator(generator, request, request_id):
|
||||
async for chunk in generator:
|
||||
if request_id not in save_dict:
|
||||
save_dict[request_id] = ""
|
||||
import json
|
||||
try:
|
||||
data = chunk.strip()
|
||||
if data.startswith('data: '):
|
||||
data = data[len('data: '):]
|
||||
else:
|
||||
yield chunk
|
||||
json_data = json.loads(data)
|
||||
if 'choices' in json_data and len(json_data['choices']) > 0:
|
||||
choice = json_data['choices'][0]
|
||||
if 'delta' in choice:
|
||||
save_dict[request_id] += choice["delta"]["content"]
|
||||
elif 'text' in choice:
|
||||
save_dict[request_id] += choice["text"]
|
||||
except json.JSONDecodeError:
|
||||
print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
|
||||
pass # Done
|
||||
yield chunk
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
@router.post("/v1/chat/completions",
|
||||
dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
handler = chat(raw_request)
|
||||
|
|
@ -401,7 +507,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||
request_id = "chatcmpl-" \
|
||||
f"{handler._base_request_id(raw_request, request.request_id)}"
|
||||
print(f"First received request_id: {request_id}, request: {request}")
|
||||
|
||||
generator = await handler.create_chat_completion(request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
|
|
@ -418,8 +523,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
|||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/completions")
|
||||
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
handler = completion(raw_request)
|
||||
if handler is None:
|
||||
|
|
@ -438,14 +544,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
|||
if flag is not None:
|
||||
print(f"Received request-id:{request_id}, request:{request}, Output:{generator.model_dump()}")
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
|
||||
if flag is not None:
|
||||
return StreamingResponse(content=stream_generator(generator, request, request_id), media_type="text/event-stream")
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
handler = embedding(raw_request)
|
||||
if handler is None:
|
||||
|
|
@ -460,6 +567,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
|||
"use the Pooling API (`/pooling`) instead.")
|
||||
|
||||
res = await fallback_handler.create_pooling(request, raw_request)
|
||||
|
||||
generator: Union[ErrorResponse, EmbeddingResponse]
|
||||
if isinstance(res, PoolingResponse):
|
||||
generator = EmbeddingResponse(
|
||||
id=res.id,
|
||||
|
|
@ -488,8 +597,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
|||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/pooling")
|
||||
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||
handler = pooling(raw_request)
|
||||
if handler is None:
|
||||
|
|
@ -506,8 +616,9 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
|
|||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/score")
|
||||
@router.post("/score", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score(request: ScoreRequest, raw_request: Request):
|
||||
handler = score(raw_request)
|
||||
if handler is None:
|
||||
|
|
@ -524,8 +635,9 @@ async def create_score(request: ScoreRequest, raw_request: Request):
|
|||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/score")
|
||||
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that Score API is not part of standard OpenAI API, we "
|
||||
|
|
@ -534,6 +646,160 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
|||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/v1/audio/transcriptions")
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_transcriptions(request: Annotated[TranscriptionRequest,
|
||||
Form()],
|
||||
raw_request: Request):
|
||||
handler = transcription(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Transcriptions API")
|
||||
|
||||
audio_data = await request.file.read()
|
||||
generator = await handler.create_transcription(audio_data, request,
|
||||
raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
|
||||
elif isinstance(generator, TranscriptionResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
handler = rerank(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Rerank (Score) API")
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning_once(
|
||||
"To indicate that the rerank API is not part of the standard OpenAI"
|
||||
" API, we have located it at `/rerank`. Please update your client "
|
||||
"accordingly. (Note: Conforms to JinaAI rerank API)")
|
||||
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
TASK_HANDLERS: dict[str, dict[str, tuple]] = {
|
||||
"generate": {
|
||||
"messages": (ChatCompletionRequest, create_chat_completion),
|
||||
"default": (CompletionRequest, create_completion),
|
||||
},
|
||||
"embed": {
|
||||
"messages": (EmbeddingChatRequest, create_embedding),
|
||||
"default": (EmbeddingCompletionRequest, create_embedding),
|
||||
},
|
||||
"score": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"rerank": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"reward": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
"default": (PoolingCompletionRequest, create_pooling),
|
||||
},
|
||||
"classify": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
"default": (PoolingCompletionRequest, create_pooling),
|
||||
},
|
||||
}
|
||||
|
||||
if envs.VLLM_SERVER_DEV_MODE:
|
||||
|
||||
@router.post("/reset_prefix_cache")
|
||||
async def reset_prefix_cache(raw_request: Request):
|
||||
"""
|
||||
Reset the prefix cache. Note that we currently do not check if the
|
||||
prefix cache is successfully reset in the API server.
|
||||
"""
|
||||
device = None
|
||||
device_str = raw_request.query_params.get("device")
|
||||
if device_str is not None:
|
||||
device = Device[device_str.upper()]
|
||||
logger.info("Resetting prefix cache with specific %s...", str(device))
|
||||
await engine_client(raw_request).reset_prefix_cache(device)
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/sleep")
|
||||
async def sleep(raw_request: Request):
|
||||
# get POST params
|
||||
level = raw_request.query_params.get("level", "1")
|
||||
await engine_client(raw_request).sleep(int(level))
|
||||
# FIXME: in v0 with frontend multiprocessing, the sleep command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/wake_up")
|
||||
async def wake_up(raw_request: Request):
|
||||
tags = raw_request.query_params.getlist("tags")
|
||||
if tags == []:
|
||||
# set to None to wake up all tags if no tags are provided
|
||||
tags = None
|
||||
logger.info("wake up the engine with tags: %s", tags)
|
||||
await engine_client(raw_request).wake_up(tags)
|
||||
# FIXME: in v0 with frontend multiprocessing, the wake-up command
|
||||
# is sent but does not finish yet when we return a response.
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.get("/is_sleeping")
|
||||
async def is_sleeping(raw_request: Request):
|
||||
logger.info("check whether the engine is sleeping")
|
||||
is_sleeping = await engine_client(raw_request).is_sleeping()
|
||||
return JSONResponse(content={"is_sleeping": is_sleeping})
|
||||
|
||||
|
||||
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
|
||||
async def invocations(raw_request: Request):
|
||||
"""
|
||||
For SageMaker, routes requests to other handlers based on model `task`.
|
||||
"""
|
||||
body = await raw_request.json()
|
||||
task = raw_request.app.state.task
|
||||
|
||||
if task not in TASK_HANDLERS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported task: '{task}' for '/invocations'. "
|
||||
f"Expected one of {set(TASK_HANDLERS.keys())}")
|
||||
|
||||
handler_config = TASK_HANDLERS[task]
|
||||
if "messages" in body:
|
||||
request_model, handler = handler_config["messages"]
|
||||
else:
|
||||
request_model, handler = handler_config["default"]
|
||||
|
||||
# this is required since we lose the FastAPI automatic casting
|
||||
request = request_model.model_validate(body)
|
||||
return await handler(request, raw_request)
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "

@@ -556,32 +822,30 @@ if envs.VLLM_TORCH_PROFILER_DIR:

if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "LoRA dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
    @router.post("/v1/load_lora_adapter",
                 dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest,
                                raw_request: Request):
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                response = await handler.load_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)
        handler = models(raw_request)
        response = await handler.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
    @router.post("/v1/unload_lora_adapter",
                 dependencies=[Depends(validate_json_request)])
    async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
                                  raw_request: Request):
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                response = await handler.unload_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)
        handler = models(raw_request)
        response = await handler.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)
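A hypothetical client-side sketch of driving the dynamic-LoRA endpoints above (only registered when VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled on the server). The adapter name and path are placeholders, and the request field names are assumed from the Load/UnloadLoRAAdapterRequest models and may differ across versions:

    # Illustrative only: register, then remove, a LoRA adapter at runtime.
    import requests

    BASE = "http://localhost:8000"  # assumed host/port

    requests.post(f"{BASE}/v1/load_lora_adapter",
                  json={"lora_name": "my_adapter",
                        "lora_path": "/path/to/adapter"})
    requests.post(f"{BASE}/v1/unload_lora_adapter",
                  json={"lora_name": "my_adapter"})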
@@ -615,7 +879,8 @@ def build_app(args: Namespace) -> FastAPI:
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if token := args.api_key or envs.VLLM_API_KEY:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
@@ -644,11 +909,26 @@ def build_app(args: Namespace) -> FastAPI:
        response.headers["X-Request-Id"] = request_id
        return response

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning("CAUTION: Enabling log response in the API Server. "
                       "This can include sensitive information and should be "
                       "avoided in production.")

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            response_body = [
                section async for section in response.body_iterator
            ]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            logger.info("response_body={%s}", response_body[0].decode())
            return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
@@ -658,7 +938,7 @@ def build_app(args: Namespace) -> FastAPI:
    return app


def init_app_state(
async def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
@@ -683,15 +963,36 @@ def init_app_state(
    state.log_stats = not args.disable_log_stats

    resolved_chat_template = load_chat_template(args.chat_template)
    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check official template
        tokenizer = await engine_client.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            # The warning is logged in resolve_mistral_chat_template.
            resolved_chat_template = resolve_mistral_chat_template(
                chat_template=resolved_chat_template)
        else:
            hf_chat_template = resolve_hf_chat_template(
                tokenizer,
                chat_template=None,
                tools=None,
                trust_remote_code=model_config.trust_remote_code)

            if hf_chat_template != resolved_chat_template:
                logger.warning(
                    "Using supplied chat template: %s\n"
                    "It is different from official chat template '%s'. "
                    "This discrepancy may lead to performance degradation.",
                    resolved_chat_template, args.model)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    # TODO: The chat template is now broken for lora adapters :(
    await state.openai_serving_models.init_static_loras()
    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
@@ -703,6 +1004,8 @@ def init_app_state(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        enable_reasoning=args.enable_reasoning,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    state.openai_serving_completion = OpenAIServingCompletion(
@@ -728,7 +1031,13 @@ def init_app_state(
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.task == "embed" else None
    state.openai_serving_scores = OpenAIServingScores(
    state.openai_serving_scores = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger) if model_config.task in (
            "score", "embed", "pooling") else None
    state.jinaai_serving_reranking = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
@@ -742,92 +1051,26 @@ def init_app_state(
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    )
    state.openai_serving_transcription = OpenAIServingTranscription(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
    ) if model_config.runner_type == "transcription" else None
    state.task = model_config.task
    # if args.served_model_name is not None:
    #     served_model_names = args.served_model_name
    # else:
    #     served_model_names = [args.model]

    # if args.disable_log_requests:
    #     request_logger = None
    # else:
    #     request_logger = RequestLogger(max_log_len=args.max_log_len)

    # base_model_paths = [
    #     BaseModelPath(name=name, model_path=args.model)
    #     for name in served_model_names
    # ]

    # state.engine_client = engine_client
    # state.log_stats = not args.disable_log_stats

    # resolved_chat_template = load_chat_template(args.chat_template)
    # logger.info("Using supplied chat template:\n%s", resolved_chat_template)

    # state.openai_serving_chat = OpenAIServingChat(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     args.response_role,
    #     lora_modules=args.lora_modules,
    #     prompt_adapters=args.prompt_adapters,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    #     enable_auto_tools=args.enable_auto_tool_choice,
    #     tool_parser=args.tool_call_parser,
    #     enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    # ) if model_config.runner_type == "generate" else None
    # state.openai_serving_completion = OpenAIServingCompletion(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     lora_modules=args.lora_modules,
    #     prompt_adapters=args.prompt_adapters,
    #     request_logger=request_logger,
    #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    # ) if model_config.runner_type == "generate" else None
    # state.openai_serving_pooling = OpenAIServingPooling(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    # ) if model_config.runner_type == "pooling" else None
    # state.openai_serving_embedding = OpenAIServingEmbedding(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    # ) if model_config.task == "embed" else None
    # state.openai_serving_scores = OpenAIServingScores(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     request_logger=request_logger
    # ) if model_config.task == "score" else None
    # state.openai_serving_tokenization = OpenAIServingTokenization(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     lora_modules=args.lora_modules,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    # )
    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    sock.bind(addr)

    return sock
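A minimal usage sketch of the helper above (illustrative only; the address and port are assumptions). The socket is created and bound up front so the port is reserved before the engine finishes initializing:

    # Hypothetical: reserve the listening port early, then hand the socket over.
    sock = create_server_socket(("0.0.0.0", 8000))
    print("bound to", sock.getsockname())
    sock.close()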
@@ -840,11 +1083,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valide_tool_parses = ToolParserManager.tool_parsers.keys()
    valid_tool_parses = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
            and args.tool_call_parser not in valide_tool_parses:
            and args.tool_call_parser not in valid_tool_parses:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valide_tool_parses)} }})")
                       f"(chose from {{ {','.join(valid_tool_parses)} }})")

    valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
    if args.enable_reasoning \
            and args.reasoning_parser not in valid_reasoning_parses:
        raise KeyError(
            f"invalid reasoning parser: {args.reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
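For orientation only (a sketch, not part of the patch; the import path is assumed from the ToolParserManager references above and may differ between versions), the check above simply requires the configured parser name to be among the registered ones:

    # Hypothetical: list the tool parser names that --tool-call-parser accepts.
    from vllm.entrypoints.openai.tool_parsers import ToolParserManager

    print(sorted(ToolParserManager.tool_parsers.keys()))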
    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.

@@ -866,13 +1116,28 @@ async def run_server(args, **uvicorn_kwargs) -> None:
    app = build_app(args)

    model_config = await engine_client.get_model_config()
    init_app_state(engine_client, model_config, app.state, args)
    await init_app_state(engine_client, model_config, app.state, args)

    def _listen_addr(a: str) -> str:
        if is_valid_ipv6_address(a):
            return '[' + a + ']'
        return a or "0.0.0.0"

    is_ssl = args.ssl_keyfile and args.ssl_certfile
    logger.info("Starting vLLM API server on http%s://%s:%d",
                "s" if is_ssl else "", _listen_addr(sock_addr[0]),
                sock_addr[1])

    shutdown_task = await serve_http(
        app,
        sock=sock,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.uvicorn_log_level,
        # NOTE: When the 'disable_uvicorn_access_log' value is True,
        # no access log will be output.
        access_log=not args.disable_uvicorn_access_log,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,

@@ -882,16 +1147,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
    )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

    sock.close()
    try:
        await shutdown_task
    finally:
        sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
                   "instead of `vllm.entrypoints.openai.api_server` to start the API server")
    cli_env_setup()
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)

@@ -48,7 +48,7 @@ def _sample_get_logits(
    logits = lm_head(hidden_states)
    if embedding_bias is not None:
        logits += embedding_bias
    if self.use_gather:
    if self.use_all_gather:
        logits = tensor_model_parallel_gather(logits)
    else:
        logits = tensor_model_parallel_all_gather(logits)

@@ -63,6 +63,8 @@ def _model_sample_convert():


def _ipex_llm_convert(load_in_low_bit):
    # import pdb
    # pdb.set_trace()
    from vllm.worker.xpu_model_runner import XPUModelRunner
    from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
    from ipex_llm.vllm.xpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper

@@ -99,7 +101,8 @@ def get_load_function(low_bit):
                "codegeex4-all" in self.vllm_config.model_config.model.lower() or
                "chatglm" in self.vllm_config.model_config.model.lower()) and \
                "gptq" not in self.model_config.model.lower() and \
                "awq" not in self.model_config.model.lower():
                "awq" not in self.model_config.model.lower() and \
                "qwen3" not in self.model_config.model.lower():
            self.model.apply(padding_mlp)
            from ipex_llm import optimize_model
            not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)