vLLM: update vLLM XPU to 0.8.3 version (#13118)

parent f66eee1d1d · commit 51b41faad7
11 changed files with 4608 additions and 28217 deletions
@@ -1,7 +1,9 @@
-From d345631f78a2f33ff1ddd7d9908b288eb0afaf46 Mon Sep 17 00:00:00 2001
-From: Huajun Li <huajun.li@.com>
-Date: Fri, 24 May 2024 09:47:26 +0800
-Subject: [PATCH 1/3] allreduce optimization with LL256 for Arc770 dGPU
+From dfe1851b59df6859829b447353307b7c916ccee0 Mon Sep 17 00:00:00 2001
+From: junhansh <junhan.shi@intel.com>
+Date: Mon, 28 Apr 2025 23:33:11 +0800
+Subject: [PATCH] oneccl for Arc770 V2025.0.0.6.7
 
+allreduce optimization with LL256 for Arc770 dGPU
+
 To enable this feature, please set env var:
 export CCL_DG2_ALLREDUCE=1
@@ -12,6 +14,15 @@ Build:
 3. cmake .. -GNinja -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCMAKE_CXX_FLAGS="-fsycl" -DCOMPUTE_BACKEND=dpcpp -DCMAKE_BUILD_TYPE=MinSizeRel
 4. ninja
 5. ls -al src/libccl*
+
+Changes:
+optimize req_workgroup calculate
+
+Revert "optimize req_workgroup calculate" for hang issue
+
+This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
+
+fix_fdset_buffer_overflow_issue
 ---
  src/CMakeLists.txt        |   2 +
  src/coll/coll.cpp         |  30 +-
@@ -20,9 +31,9 @@ Build:
  src/common/env/env.cpp    |   1 +
  src/common/env/env.hpp    |   1 +
  src/common/env/vars.hpp   |   1 +
- src/dg2/dg2_allreduce.cpp | 642 +++++++++++++++++++++++++++++++
+ src/dg2/dg2_allreduce.cpp | 640 +++++++++++++++++++++++++++++++
  src/dg2/dg2_allreduce.hpp |  13 +
- 9 files changed, 693 insertions(+), 3 deletions(-)
+ 9 files changed, 691 insertions(+), 3 deletions(-)
  create mode 100644 src/dg2/dg2_allreduce.cpp
  create mode 100644 src/dg2/dg2_allreduce.hpp

@@ -163,10 +174,10 @@ index 73dcf77..84ab518 100644
 constexpr const char* CCL_MIN_CHUNK_SIZE = "CCL_MIN_CHUNK_SIZE";
 diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
 new file mode 100644
-index 0000000..15ace74
+index 0000000..73e114b
 --- /dev/null
 +++ b/src/dg2/dg2_allreduce.cpp
-@@ -0,0 +1,642 @@
+@@ -0,0 +1,640 @@
 +#include <fcntl.h>
 +#include <unistd.h>
 +#include <sys/un.h>
@@ -178,7 +189,7 @@ index 0000000..15ace74
 +#include <drm/drm.h>
 +
 +#include <mpi.h>
-+
++#include <poll.h>
 +#include <vector>
 +#include <sstream>
 +#include <iostream>
@@ -315,7 +326,6 @@ index 0000000..15ace74
 +
 +static void *thread_func(void *arg)
 +{
-+    fd_set fds;
 +    int count = 0;
 +    char sock_path[64];
 +    int peer_buf_fd = 0;
@@ -323,6 +333,10 @@ index 0000000..15ace74
 +
 +    snprintf(sock_path, sizeof(sock_path), "%s-%d_%d", SOCK_PATH, rank, 0xa770);
 +    int srv_fd = srv_sock(sock_path);
++    if (srv_fd < 0) {
++        perror("srv_sock failed");
++        return nullptr;
++    }
 +
 +    //std::cout << "-----> srv_fd of " << sock_path << " : " << srv_fd << "\n";
 +
@@ -331,35 +345,30 @@ index 0000000..15ace74
 +    ze_context_handle_t ze_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
 +    ze_device_handle_t ze_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
 +
-+    FD_ZERO(&fds);
-+    FD_SET(srv_fd, &fds);
++    struct pollfd pfd = {
++        .fd = srv_fd,
++        .events = POLL_IN,
++        .revents = 0
++    };
 +    while (++count < world_size) {
-+        int ret = select(srv_fd + 1, &fds, NULL, NULL, NULL);
-+        switch (ret) {
-+            case 1:
-+            {
-+                int peer_rank;
-+                void *peer_buf;
++        int ret = poll(&pfd, 1, -1);
++        if (ret <= 0) {
++            std::cerr << "poll failed: " << strerror(errno) << "\n";
++            break;
++        }
 +
-+                int conn_fd = accept(srv_fd, NULL, 0);
-+                ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
++        if (pfd.revents & POLL_IN) {
++            int peer_rank;
++            void *peer_buf = nullptr;
 +
-+                ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
-+                zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED /* cached allocation */, &peer_buf);
++            int conn_fd = accept(srv_fd, NULL, 0);
++            ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
++            ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
++            zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED, &peer_buf);
 +
 +            peer_bufs[peer_rank] = peer_buf;
 +            //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
-+
-+                if (conn_fd > 0) close(conn_fd);
-+
-+                break;
-+            }
-+            case 0:
-+            case -1:
-+                std::cout << "srv_fd select() failed" << "\n";
-+                break;
-+            default:
-+                break;
++            if (conn_fd > 0) close(conn_fd);
 +        }
 +    }
 +
@@ -831,105 +840,3 @@ index 0000000..0506445
 --
 2.34.1
 
-
-From 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa Mon Sep 17 00:00:00 2001
-From: Huajun Li <huajun.li@.com>
-Date: Fri, 7 Mar 2025 15:15:35 +0800
-Subject: [PATCH 2/3] optimize req_workgroup calculate
-
----
- src/dg2/dg2_allreduce.cpp | 25 ++-----------------------
- 1 file changed, 2 insertions(+), 23 deletions(-)
-
-diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
-index 15ace74..83270ae 100644
---- a/src/dg2/dg2_allreduce.cpp
-+++ b/src/dg2/dg2_allreduce.cpp
-@@ -527,30 +527,9 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
-         auto chunk_sz = req_workitems * LS_SZ;  /* LS_SZ bytes per work-item */
-         auto chunk_with_pattern = sg_sz * LS_SZ;        /* aligned to 256B */
-
--        /* items will be assigned to each rank */
--        auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
--        auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
--        auto req_subgroups = 0;
--
--        if (req_workgroups >= g_sz/l_sz) {
--            req_workgroups = g_sz/l_sz;
--        } else {
--            if (group_id == (req_workgroups - 1)) {
--                req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
--
--                /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
--                /* Note: req_subgroups % (l_sz/sg_sz) might be 0 */
--                if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
--                    if ((per_rank_items % (sg_sz - 1)) != 0) {
--                        /* FIXME: */
--                        req_workitems = per_rank_items % (sg_sz - 1);
--                        chunk_sz = req_workitems * LS_SZ;  /* LS_SZ bytes per work-item */
--                    }
--                }
--            }
--        }
-+        auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
-
--        if (group_id < req_workgroups) {
-+        if (work_left > 0) {
-             // step 1: push data to next GPU
-             {
-                 offset = base + local_world_rank * chunk_sz;
---
-2.34.1
-
-
-From 1c58cc9ede5ca38138a270f9e5ff59bca74f51d4 Mon Sep 17 00:00:00 2001
-From: Huajun Li <huajun.li@.com>
-Date: Wed, 12 Mar 2025 13:29:27 +0800
-Subject: [PATCH 3/3] Revert "optimize req_workgroup calculate" for hang issue
-
-This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
----
- src/dg2/dg2_allreduce.cpp | 25 +++++++++++++++++++++++--
- 1 file changed, 23 insertions(+), 2 deletions(-)
-
-diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
-index 83270ae..15ace74 100644
---- a/src/dg2/dg2_allreduce.cpp
-+++ b/src/dg2/dg2_allreduce.cpp
-@@ -527,9 +527,30 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
-         auto chunk_sz = req_workitems * LS_SZ;  /* LS_SZ bytes per work-item */
-         auto chunk_with_pattern = sg_sz * LS_SZ;        /* aligned to 256B */
-
--        auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
-+        /* items will be assigned to each rank */
-+        auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
-+        auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
-+        auto req_subgroups = 0;
-+
-+        if (req_workgroups >= g_sz/l_sz) {
-+            req_workgroups = g_sz/l_sz;
-+        } else {
-+            if (group_id == (req_workgroups - 1)) {
-+                req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
-+
-+                /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
-+                /* Note: req_subgroups % (l_sz/sg_sz) might be 0 */
-+                if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
-+                    if ((per_rank_items % (sg_sz - 1)) != 0) {
-+                        /* FIXME: */
-+                        req_workitems = per_rank_items % (sg_sz - 1);
-+                        chunk_sz = req_workitems * LS_SZ;  /* LS_SZ bytes per work-item */
-+                    }
-+                }
-+            }
-+        }
-
--        if (work_left > 0) {
-+        if (group_id < req_workgroups) {
-             // step 1: push data to next GPU
-             {
-                 offset = base + local_world_rank * chunk_sz;
---
-2.34.1
-
-
@ -54,19 +54,20 @@ RUN set -eux && \
|
||||||
#
|
#
|
||||||
# Install Intel PyTorch extension for LLM inference
|
# Install Intel PyTorch extension for LLM inference
|
||||||
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
|
pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
|
||||||
|
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
|
||||||
#
|
#
|
||||||
# Build torch-ccl
|
# Build torch-ccl
|
||||||
mkdir -p /build && \
|
mkdir -p /build && \
|
||||||
cd /build && \
|
cd /build && \
|
||||||
git clone https://github.com/intel/torch-ccl.git && \
|
git clone https://github.com/intel/torch-ccl.git && \
|
||||||
cd torch-ccl && \
|
cd torch-ccl && \
|
||||||
git checkout ccl_torch2.5.0+xpu && \
|
git checkout ccl_torch2.6.0+xpu && \
|
||||||
git submodule sync && \
|
git submodule sync && \
|
||||||
git submodule update --init --recursive && \
|
git submodule update --init --recursive && \
|
||||||
# This patch will enable build torch-ccl with pytorch 2.6 environment
|
# This patch will enable build torch-ccl with pytorch 2.6 environment
|
||||||
git apply /tmp/ccl_torch.patch && \
|
# git apply /tmp/ccl_torch.patch && \
|
||||||
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel && \
|
USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel && \
|
||||||
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
|
# File path: /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl
|
||||||
# Build oneCCL
|
# Build oneCCL
|
||||||
pip install ninja && \
|
pip install ninja && \
|
||||||
cd /build/ && \
|
cd /build/ && \
|
||||||
|
|
@ -85,7 +86,7 @@ RUN set -eux && \
|
||||||
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
|
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04
|
||||||
|
|
||||||
# Copy the built torch-ccl package from the build stage
|
# Copy the built torch-ccl package from the build stage
|
||||||
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
|
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
|
||||||
COPY --from=build /llm/ /llm/
|
COPY --from=build /llm/ /llm/
|
||||||
COPY --from=build /build/oneCCL/build/src/libccl.so.1.0 /opt/intel/1ccl-wks/lib/
|
COPY --from=build /build/oneCCL/build/src/libccl.so.1.0 /opt/intel/1ccl-wks/lib/
|
||||||
COPY --from=build /build/oneCCL/build/src/libccl.so.1 /opt/intel/1ccl-wks/lib/
|
COPY --from=build /build/oneCCL/build/src/libccl.so.1 /opt/intel/1ccl-wks/lib/
|
||||||
|
|
@ -144,9 +145,10 @@ RUN set -eux && \
|
||||||
# Install vllm dependencies
|
# Install vllm dependencies
|
||||||
pip install --upgrade fastapi && \
|
pip install --upgrade fastapi && \
|
||||||
pip install --upgrade "uvicorn[standard]" && \
|
pip install --upgrade "uvicorn[standard]" && \
|
||||||
|
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
|
||||||
#
|
#
|
||||||
# Install torch-ccl
|
# Install torch-ccl
|
||||||
pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
|
pip install /opt/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl && \
|
||||||
#
|
#
|
||||||
apt-get update && \
|
apt-get update && \
|
||||||
apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
|
apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
|
||||||
|
|
@ -168,21 +170,19 @@ RUN set -eux && \
|
||||||
mkdir -p /llm && \
|
mkdir -p /llm && \
|
||||||
cd /llm && \
|
cd /llm && \
|
||||||
rm -rf /tmp/neo && \
|
rm -rf /tmp/neo && \
|
||||||
# Install intel_extension_for_pytorch
|
|
||||||
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \
|
|
||||||
pip uninstall -y oneccl oneccl-devel && \
|
|
||||||
pip install intel-opencl-rt==2025.0.2 intel-openmp==2025.0.2 && \
|
|
||||||
#
|
|
||||||
# Install vllm
|
# Install vllm
|
||||||
git clone -b v0.6.6.post1 https://github.com/vllm-project/vllm /llm/vllm && \
|
git clone -b v0.8.3 https://github.com/vllm-project/vllm /llm/vllm && \
|
||||||
cd /llm/vllm && \
|
cd /llm/vllm && \
|
||||||
git apply /llm/vllm_for_multi_arc.patch && \
|
git apply /llm/vllm_for_multi_arc.patch && \
|
||||||
pip install setuptools-scm && \
|
pip install setuptools-scm==8.2.0 setuptools==78.1.0 && \
|
||||||
pip install --upgrade cmake && \
|
pip install --upgrade cmake && \
|
||||||
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
|
pip install -v -r requirements/xpu.txt && \
|
||||||
|
VLLM_TARGET_DEVICE=xpu python setup.py install && \
|
||||||
|
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
|
||||||
|
pip uninstall -y oneccl oneccl-devel && \
|
||||||
rm -rf /llm/vllm_for_multi_arc.patch && \
|
rm -rf /llm/vllm_for_multi_arc.patch && \
|
||||||
pip install mpi4py fastapi uvicorn openai && \
|
pip install mpi4py fastapi uvicorn openai && \
|
||||||
pip install ray
|
pip install ray numba
|
||||||
|
|
||||||
|
|
||||||
WORKDIR /llm/
|
WORKDIR /llm/
|
||||||
|
|
|
||||||
|
|
@@ -32,6 +32,9 @@ export TORCH_LLM_ALLREDUCE=0
 export CCL_SAME_STREAM=1
 export CCL_BLOCKING_WAIT=0
 
+export VLLM_USE_V1=0
+export IPEX_LLM_LOWBIT=$LOAD_IN_LOW_BIT
+
 source /opt/intel/1ccl-wks/setvars.sh
 
 python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
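Once the start script above brings up ipex_llm.vllm.xpu.entrypoints.openai.api_server, the service speaks the OpenAI-compatible API. A minimal client-side sketch, assuming the default port 8000 and that the model name and API key below are placeholders matching whatever the script passed via --served-model-name and --api-key:

# Hedged sketch: query the OpenAI-compatible endpoint started by the script above.
# "YOUR_MODEL_NAME", the API key and the port are assumptions, not values from this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="<your-api-key>")

response = client.chat.completions.create(
    model="YOUR_MODEL_NAME",  # must match --served-model-name
    messages=[{"role": "user", "content": "What is IPEX-LLM?"}],
    max_tokens=128,
)
print(response.choices[0].message.content)
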
File diff suppressed because one or more lines are too long
@@ -782,6 +782,9 @@ export USE_XETLA=OFF
 export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
 export TORCH_LLM_ALLREDUCE=0
 
+export VLLM_USE_V1=0
+export IPEX_LLM_LOWBIT=fp8
+
 source /opt/intel/1ccl-wks/setvars.sh
 
 python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
@@ -793,7 +796,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --device xpu \
   --dtype float16 \
   --enforce-eager \
-  --load-in-low-bit fp8 \
+  --load-in-low-bit $IPEX_LLM_LOWBIT \
   --max-model-len 2048 \
   --max-num-batched-tokens 4000 \
   --api-key <your-api-key> \
@@ -50,9 +50,14 @@ pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorc
 pip install setuptools-scm
 pip install --upgrade cmake
 # cd to your workdir
-git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
+git clone -b 0.8.3 https://github.com/analytics-zoo/vllm.git
 cd vllm
-VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
+pip install setuptools-scm==8.2.0 setuptools==78.1.0
+pip install --upgrade cmake
+pip install -v -r requirements/xpu.txt
+VLLM_TARGET_DEVICE=xpu python setup.py install
+pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
+pip uninstall -y oneccl oneccl-devel
 # For Qwen model support
 pip install transformers_stream_generator einops tiktoken
 pip install ray
@@ -93,6 +98,8 @@ For vLLM, you can start the service using the following command:
 model="YOUR_MODEL_PATH"
 served_model_name="YOUR_MODEL_NAME"
 export VLLM_RPC_TIMEOUT=100000
+export VLLM_USE_V1=0
+export IPEX_LLM_LOWBIT=fp8
 
 # You may need to adjust the value of
 # --max-model-len, --max-num-batched-tokens, --max-num-seqs
@@ -107,7 +114,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
   --device xpu \
   --dtype float16 \
   --enforce-eager \
-  --load-in-low-bit sym_int4 \
+  --load-in-low-bit $IPEX_LLM_LOWBIT \
   --max-model-len 4096 \
   --max-num-batched-tokens 10240 \
   --max-num-seqs 12 \
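Besides the OpenAI-compatible server shown above, the same low-bit path can be exercised offline through the LLM-style wrapper this commit touches (IPEXLLMClass, exported from ipex_llm.vllm.xpu.engine). A minimal sketch, assuming a local model path; apart from load_in_low_bit, the arguments are the standard vLLM LLM/EngineArgs options and the specific values here are illustrative only:

# Hedged sketch of offline generation with the IPEX-LLM wrapper around vLLM's LLM class.
# "YOUR_MODEL_PATH" is a placeholder; load_in_low_bit accepts values such as "sym_int4" or "fp8".
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

llm = LLM(model="YOUR_MODEL_PATH",
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="fp8",
          max_model_len=2048)

outputs = llm.generate(["What is AI?"], SamplingParams(max_tokens=64))
for out in outputs:
    print(out.outputs[0].text)
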
@@ -150,12 +150,13 @@ def is_linear_module(module):
     if _VLLM_VERSION is None:
         _VLLM_VERSION = get_package_version('vllm')
     from vllm.model_executor.layers.linear import (
-        ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
+        ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
+        MergedColumnParallelLinear, ReplicatedLinear
     )
     from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
     VLLM_LINEAR_LIST = [
         ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
-        MergedColumnParallelLinear,
+        MergedColumnParallelLinear, ReplicatedLinear,
     ]
     if 'xpu' in _VLLM_VERSION:
         VLLM_LINEAR_LIST.append(ParallelLMHead)
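The effect of adding ReplicatedLinear to VLLM_LINEAR_LIST is that layers of that type are now also picked up by the low-bit conversion. A simplified sketch of the kind of membership check is_linear_module performs (illustrative only, not the actual implementation):

# Hedged sketch: membership test against the linear-layer whitelist used during conversion.
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
    MergedColumnParallelLinear, ReplicatedLinear,
)

VLLM_LINEAR_LIST = [
    ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
    MergedColumnParallelLinear, ReplicatedLinear,
]

def is_vllm_linear(module) -> bool:
    # True when the module is one of the vLLM linear layers eligible for low-bit replacement.
    return any(isinstance(module, cls) for cls in VLLM_LINEAR_LIST)
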
|
@ -13,10 +13,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
|
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine, IPEXLLMAsyncV1Engine, IPEXLLMLLMV1Engine
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"IPEXLLMAsyncLLMEngine",
|
"IPEXLLMAsyncLLMEngine",
|
||||||
"IPEXLLMLLMEngine",
|
"IPEXLLMLLMEngine",
|
||||||
"IPEXLLMClass",
|
"IPEXLLMClass",
|
||||||
|
"IPEXLLMAsyncV1Engine",
|
||||||
|
"IPEXLLMLLMV1Engine",
|
||||||
"run_mp_engine",
|
"run_mp_engine",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@@ -38,6 +38,8 @@ logger = init_logger(__name__)
 
 
 class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
+    _is_converted = False
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -53,13 +55,39 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        _ipex_llm_convert(load_in_low_bit)
+        if not cls._is_converted:
+            _ipex_llm_convert(load_in_low_bit)
+            cls._is_converted = True
         return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
                                         start_engine_loop=start_engine_loop,
                                         usage_context=usage_context, stat_loggers=stat_loggers)
 
+    @classmethod
+    def from_vllm_config(
+        cls,
+        vllm_config: VllmConfig,
+        start_engine_loop: bool = True,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
+        disable_log_requests: bool = False,
+        disable_log_stats: bool = False,
+        load_in_low_bit: str = "sym_int4",
+    ) -> "AsyncLLMEngine":
+        if not cls._is_converted:
+            _ipex_llm_convert(load_in_low_bit)
+            cls._is_converted = True
+        return super().from_vllm_config(
+            vllm_config=vllm_config,
+            start_engine_loop=start_engine_loop,
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+            disable_log_requests=disable_log_requests,
+            disable_log_stats=disable_log_stats,
+        )
+
 
 class IPEXLLMAsyncV1Engine(AsyncLLM):
+    _is_converted = False
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
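The pattern added throughout these engine wrappers is the same: a class-level _is_converted flag so that _ipex_llm_convert(load_in_low_bit) runs at most once per process, no matter which classmethod constructor (from_engine_args or the new from_vllm_config) is hit first. A stripped-down, self-contained sketch of that guard; convert_model() and BaseEngine below are stand-ins, not names from this commit:

# Hedged sketch of the convert-once guard used by the IPEX-LLM engine wrappers.
def convert_model(load_in_low_bit: str) -> None:
    # Stand-in for _ipex_llm_convert: patch the model-loading path once.
    print(f"converting model once, low-bit={load_in_low_bit}")


class BaseEngine:
    @classmethod
    def from_engine_args(cls, engine_args, **kwargs):
        return cls()


class MyEngine(BaseEngine):
    _is_converted = False            # class-level flag shared by all constructors

    @classmethod
    def from_engine_args(cls, engine_args, load_in_low_bit="sym_int4", **kwargs):
        if not cls._is_converted:    # run the low-bit conversion at most once
            convert_model(load_in_low_bit)
            cls._is_converted = True
        return super().from_engine_args(engine_args, **kwargs)


engine_a = MyEngine.from_engine_args(None, load_in_low_bit="fp8")
engine_b = MyEngine.from_engine_args(None)   # second call skips the conversion
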
@ -74,13 +102,39 @@ class IPEXLLMAsyncV1Engine(AsyncLLM):
|
||||||
load_in_low_bit: str = "sym_int4",
|
load_in_low_bit: str = "sym_int4",
|
||||||
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, # noqa
|
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None, # noqa
|
||||||
) -> "AsyncLLM":
|
) -> "AsyncLLM":
|
||||||
_ipex_llm_convert(load_in_low_bit)
|
if not cls._is_converted:
|
||||||
|
_ipex_llm_convert(load_in_low_bit)
|
||||||
|
cls._is_converted = True
|
||||||
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
|
return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
|
||||||
start_engine_loop=start_engine_loop,
|
start_engine_loop=start_engine_loop,
|
||||||
usage_context=usage_context, stat_loggers=stat_loggers)
|
usage_context=usage_context, stat_loggers=stat_loggers)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_vllm_config(
|
||||||
|
cls,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
start_engine_loop: bool = True,
|
||||||
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
|
stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
|
||||||
|
disable_log_requests: bool = False,
|
||||||
|
disable_log_stats: bool = False,
|
||||||
|
load_in_low_bit: str = "sym_int4",
|
||||||
|
) -> "AsyncLLM":
|
||||||
|
if not cls._is_converted:
|
||||||
|
_ipex_llm_convert(load_in_low_bit)
|
||||||
|
cls._is_converted = True
|
||||||
|
return super().from_vllm_config(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
start_engine_loop=start_engine_loop,
|
||||||
|
usage_context=usage_context,
|
||||||
|
stat_loggers=stat_loggers,
|
||||||
|
disable_log_requests=disable_log_requests,
|
||||||
|
disable_log_stats=disable_log_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IPEXLLMClass(LLM):
|
class IPEXLLMClass(LLM):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
|
@@ -94,20 +148,20 @@ class IPEXLLMClass(LLM):
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
-        seed: int = 0,
+        seed: Optional[int] = None,
         gpu_memory_utilization: float = 0.9,
         swap_space: float = 4,
         cpu_offload_gb: float = 0,
         enforce_eager: Optional[bool] = None,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
-        disable_async_output_proc: bool = True,
-        hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[Dict[str, Any]]=None,
+        disable_async_output_proc: bool = False,
+        hf_overrides: Optional[HfOverrides]=None,
+        mm_processor_kwargs: Optional[dict[str, Any]]=None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
-        compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
+        compilation_config: Optional[Union[int, dict[str, Any]]]=None,
         load_in_low_bit: str = "sym_int4",
         **kwargs,
     ) -> None:
@@ -120,6 +174,13 @@ class IPEXLLMClass(LLM):
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if "worker_cls" in kwargs:
+            worker_cls = kwargs["worker_cls"]
+            # if the worker_cls is not qualified string name,
+            # we serialize it using cloudpickle to avoid pickling issues
+            if isinstance(worker_cls, type):
+                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
+
         if compilation_config is not None:
             if isinstance(compilation_config, (int, dict)):
                 compilation_config_instance = CompilationConfig.from_cli(
@@ -159,11 +220,13 @@ class IPEXLLMClass(LLM):
         # Logic to switch between engines is done at runtime instead of import
         # to avoid import order issues
         self.engine_class = self.get_engine_class()
+        # print("!!! ", load_in_low_bit)
         self.llm_engine = self.engine_class.from_engine_args(
             engine_args, usage_context=UsageContext.LLM_CLASS,
             load_in_low_bit=load_in_low_bit)
 
         self.request_counter = Counter()
+        self.default_sampling_params: Union[dict[str, Any], None] = None
 
     @staticmethod
     def get_engine_class() -> Type[LLMEngine]:
@ -173,6 +236,8 @@ class IPEXLLMClass(LLM):
|
||||||
|
|
||||||
|
|
||||||
class IPEXLLMLLMV1Engine(V1LLMEngine):
|
class IPEXLLMLLMV1Engine(V1LLMEngine):
|
||||||
|
_is_converted = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -188,14 +253,37 @@ class IPEXLLMLLMV1Engine(V1LLMEngine):
|
||||||
"""Creates an LLM engine from the engine arguments."""
|
"""Creates an LLM engine from the engine arguments."""
|
||||||
# Create the engine configs.
|
# Create the engine configs.
|
||||||
|
|
||||||
_ipex_llm_convert(load_in_low_bit)
|
if not cls._is_converted:
|
||||||
|
_ipex_llm_convert(load_in_low_bit)
|
||||||
|
cls._is_converted = True
|
||||||
return super().from_engine_args(engine_args,
|
return super().from_engine_args(engine_args,
|
||||||
usage_context,
|
usage_context,
|
||||||
stat_loggers,
|
stat_loggers,
|
||||||
enable_multiprocessing)
|
enable_multiprocessing)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_vllm_config(
|
||||||
|
cls,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||||
|
stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
|
||||||
|
disable_log_stats: bool = False,
|
||||||
|
load_in_low_bit: str = "sym_int4",
|
||||||
|
) -> "LLMEngine":
|
||||||
|
if not cls._is_converted:
|
||||||
|
_ipex_llm_convert(load_in_low_bit)
|
||||||
|
cls._is_converted = True
|
||||||
|
return super().from_vllm_config(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
usage_context=usage_context,
|
||||||
|
stat_loggers=stat_loggers,
|
||||||
|
disable_log_stats=disable_log_stats
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class IPEXLLMLLMEngine(LLMEngine):
|
class IPEXLLMLLMEngine(LLMEngine):
|
||||||
|
_is_converted = False
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
@@ -209,33 +297,89 @@ class IPEXLLMLLMEngine(LLMEngine):
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
-        _ipex_llm_convert(load_in_low_bit)
+        if not cls._is_converted:
+            _ipex_llm_convert(load_in_low_bit)
+            cls._is_converted = True
         return super().from_engine_args(engine_args, usage_context, stat_loggers)
 
+    @classmethod
+    def from_vllm_config(
+        cls,
+        vllm_config: VllmConfig,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
+        disable_log_stats: bool = False,
+        load_in_low_bit: str = "sym_int4",
+    ) -> "LLMEngine":
+        if not cls._is_converted:
+            _ipex_llm_convert(load_in_low_bit)
+            cls._is_converted = True
+        return super().from_vllm_config(
+            vllm_config=vllm_config,
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+            disable_log_stats=disable_log_stats
+        )
+
 
 class IPEXLLMMQLLMEngine(MQLLMEngine):
+    _is_converted = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     @classmethod
     def from_engine_args(cls, engine_args: AsyncEngineArgs,
                          usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
-        _ipex_llm_convert(load_in_low_bit)
+        if not cls._is_converted:
+            _ipex_llm_convert(load_in_low_bit)
+            cls._is_converted = True
         return super().from_engine_args(engine_args, usage_context, ipc_path)
 
+    @classmethod
+    def from_vllm_config(cls, vllm_config: VllmConfig,
+                         usage_context: UsageContext,
+                         disable_log_requests: bool, disable_log_stats: bool,
+                         ipc_path: str, load_in_low_bit: str) -> "MQLLMEngine":
+        if not cls._is_converted:
+            _ipex_llm_convert(load_in_low_bit)
+            cls._is_converted = True
+        return super().from_vllm_config(
+            vllm_config=vllm_config,
+            ipc_path=ipc_path,
+            usage_context=usage_context,
+            disable_log_requests=disable_log_requests,
+            disable_log_stats=disable_log_stats,
+        )
+
 
-def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
-                  ipc_path: str, load_in_low_bit: str, engine_alive):
-
-    def signal_handler(*_) -> None:
-        # Interrupt server on sigterm
-        raise KeyboardInterrupt("MQLLMEngine terminated")  # noqa
+from vllm.transformers_utils.config import (
+    maybe_register_config_serialize_by_value)
+
+
+def signal_handler(*_) -> None:
+    raise KeyboardInterrupt("MQLLMEngine terminated")  # noqa
+
+
+def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
+                  ipc_path: str, disable_log_stats: bool,
+                  disable_log_requests: bool, load_in_low_bit, engine_alive):
     try:
+        # Ensure we can serialize transformer config before spawning
+        maybe_register_config_serialize_by_value()
+
+        engine = IPEXLLMMQLLMEngine.from_vllm_config(
+            vllm_config=vllm_config,
+            usage_context=usage_context,
+            disable_log_stats=disable_log_stats,
+            disable_log_requests=disable_log_requests,
+            load_in_low_bit=load_in_low_bit,
+            ipc_path=ipc_path)
+
         signal.signal(signal.SIGTERM, signal_handler)
 
-        engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
-                                                     usage_context=usage_context,
-                                                     ipc_path=ipc_path,
-                                                     load_in_low_bit=load_in_low_bit)
         engine.start()
 
     except BaseException as e:
         logger.exception(e)
         engine_alive.value = False
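The reshuffling above (module-level signal_handler, maybe_register_config_serialize_by_value(), cloudpickle for worker_cls) follows from the fact that the MQ engine runs in a process created with the "spawn" start method, so everything handed to it must be importable or picklable. A minimal, self-contained sketch of that constraint in generic Python (not the ipex_llm code):

# Hedged sketch: spawn-based worker processes need a picklable, module-level target
# and picklable arguments, which is why helpers are defined at module scope.
import multiprocessing


def run_engine(config: dict, alive) -> None:
    # Stand-in for run_mp_engine: receives plain-data config plus a shared flag.
    try:
        print(f"engine starting with {config}")
    except BaseException:
        alive.value = False


if __name__ == "__main__":
    ctx = multiprocessing.get_context("spawn")
    engine_alive = ctx.Value('b', True, lock=False)
    proc = ctx.Process(target=run_engine,
                       args=({"load_in_low_bit": "sym_int4"}, engine_alive))
    proc.start()
    proc.join()
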
@@ -1,5 +1,8 @@
+# SPDX-License-Identifier: Apache-2.0
+
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -10,16 +13,18 @@ import socket
 import tempfile
 import uuid
 from argparse import Namespace
+from collections.abc import AsyncIterator
 from contextlib import asynccontextmanager
 from functools import partial
 from http import HTTPStatus
-from typing import AsyncIterator, Optional, Set, Tuple
+from typing import Annotated, Optional, Union
 
 import uvloop
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.concurrency import iterate_in_threadpool
 from starlette.datastructures import State
 from starlette.routing import Mount
 from typing_extensions import assert_never
@ -27,17 +32,17 @@ from typing_extensions import assert_never
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
|
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine # type: ignore
|
||||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||||
from ipex_llm.vllm.xpu.engine import run_mp_engine
|
from ipex_llm.vllm.xpu.engine import run_mp_engine
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
from vllm.entrypoints.chat_utils import load_chat_template
|
from vllm.entrypoints.chat_utils import (load_chat_template,
|
||||||
|
resolve_hf_chat_template,
|
||||||
|
resolve_mistral_chat_template)
|
||||||
from vllm.entrypoints.launcher import serve_http
|
from vllm.entrypoints.launcher import serve_http
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
||||||
validate_parsed_serve_args)
|
validate_parsed_serve_args)
|
||||||
|
|
||||||
# from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
|
|
@ -46,33 +51,46 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||||
CompletionResponse,
|
CompletionResponse,
|
||||||
DetokenizeRequest,
|
DetokenizeRequest,
|
||||||
DetokenizeResponse,
|
DetokenizeResponse,
|
||||||
|
EmbeddingChatRequest,
|
||||||
|
EmbeddingCompletionRequest,
|
||||||
EmbeddingRequest,
|
EmbeddingRequest,
|
||||||
EmbeddingResponse,
|
EmbeddingResponse,
|
||||||
EmbeddingResponseData,
|
EmbeddingResponseData,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
LoadLoraAdapterRequest,
|
LoadLoRAAdapterRequest,
|
||||||
|
PoolingChatRequest,
|
||||||
|
PoolingCompletionRequest,
|
||||||
PoolingRequest, PoolingResponse,
|
PoolingRequest, PoolingResponse,
|
||||||
|
RerankRequest, RerankResponse,
|
||||||
ScoreRequest, ScoreResponse,
|
ScoreRequest, ScoreResponse,
|
||||||
TokenizeRequest,
|
TokenizeRequest,
|
||||||
TokenizeResponse,
|
TokenizeResponse,
|
||||||
UnloadLoraAdapterRequest)
|
TranscriptionRequest,
|
||||||
|
TranscriptionResponse,
|
||||||
|
UnloadLoRAAdapterRequest)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||||
|
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||||
OpenAIServingModels)
|
OpenAIServingModels)
|
||||||
|
|
||||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
|
||||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
from vllm.entrypoints.openai.serving_score import ServingScores
|
||||||
from vllm.entrypoints.openai.serving_tokenization import (
|
from vllm.entrypoints.openai.serving_tokenization import (
|
||||||
OpenAIServingTokenization)
|
OpenAIServingTokenization)
|
||||||
|
from vllm.entrypoints.openai.serving_transcription import (
|
||||||
|
OpenAIServingTranscription)
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
from vllm.entrypoints.utils import with_cancellation
|
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
||||||
|
with_cancellation)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.reasoning import ReasoningParserManager
|
||||||
|
from vllm.transformers_utils.config import (
|
||||||
|
maybe_register_config_serialize_by_value)
|
||||||
|
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
|
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
|
||||||
is_valid_ipv6_address, set_ulimit)
|
is_valid_ipv6_address, set_ulimit)
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
|
|
@ -83,7 +101,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||||
logger = init_logger('vllm.entrypoints.openai.api_server')
|
logger = init_logger('vllm.entrypoints.openai.api_server')
|
||||||
|
|
||||||
_running_tasks: Set[asyncio.Task] = set()
|
_running_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|
@@ -102,6 +120,11 @@ async def lifespan(app: FastAPI):
             task.add_done_callback(_running_tasks.remove)
         else:
             task = None
 
+    # Mark the startup heap as static so that it's ignored by GC.
+    # Reduces pause times of oldest generation collections.
+    gc.collect()
+    gc.freeze()
+
     try:
         yield
     finally:
@@ -139,24 +162,49 @@ async def build_async_engine_client_from_engine_args(
     Returns the Client or None if the creation failed.
     """
 
-    # Fall back
-    # TODO: fill out feature matrix.
-    if (MQLLMEngineClient.is_unsupported_config(engine_args)
-            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
+    # Create the EngineConfig (determines if we can use V1).
+    usage_context = UsageContext.OPENAI_API_SERVER
+    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
+
+    # V1 AsyncLLM.
+    if envs.VLLM_USE_V1:
+        if disable_frontend_multiprocessing:
+            logger.warning(
+                "V1 is enabled, but got --disable-frontend-multiprocessing. "
+                "To disable frontend multiprocessing, set VLLM_USE_V1=0.")
+
+        from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine as AsyncLLM
+        async_llm: Optional[AsyncLLM] = None
+        try:
+            async_llm = AsyncLLM.from_vllm_config(
+                vllm_config=vllm_config,
+                usage_context=usage_context,
+                disable_log_requests=engine_args.disable_log_requests,
+                disable_log_stats=engine_args.disable_log_stats,
+                load_in_low_bit=load_in_low_bit)
+            yield async_llm
+        finally:
+            if async_llm:
+                async_llm.shutdown()
+
+    # V0 AsyncLLM.
+    elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
+          or disable_frontend_multiprocessing):
+
         engine_client: Optional[EngineClient] = None
         try:
-            # When starting this, we are actually starting with the V1Engine
-            # Here we are doing a classification, we will need to do this in IPEX-LLM
-            engine_client = AsyncLLMEngine.from_engine_args(
-                engine_args=engine_args,
-                usage_context=UsageContext.OPENAI_API_SERVER,
+            engine_client = AsyncLLMEngine.from_vllm_config(
+                vllm_config=vllm_config,
+                usage_context=usage_context,
+                disable_log_requests=engine_args.disable_log_requests,
+                disable_log_stats=engine_args.disable_log_stats,
                 load_in_low_bit=load_in_low_bit)
             yield engine_client
         finally:
             if engine_client and hasattr(engine_client, "shutdown"):
                 engine_client.shutdown()
 
-    # Otherwise, use the multiprocessing AsyncLLMEngine.
+    # V0MQLLMEngine.
     else:
         if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
             # Make TemporaryDirectory for prometheus multiprocessing
@@ -183,14 +231,18 @@ async def build_async_engine_client_from_engine_args(
         # so we need to spawn a new process
         context = multiprocessing.get_context("spawn")
 
+        # Ensure we can serialize transformer config before spawning
+        maybe_register_config_serialize_by_value()
+
         # The Process can raise an exception during startup, which may
         # not actually result in an exitcode being reported. As a result
         # we use a shared variable to communicate the information.
         engine_alive = multiprocessing.Value('b', True, lock=False)
-        engine_process = context.Process(target=run_mp_engine,
-                                         args=(engine_args,
-                                               UsageContext.OPENAI_API_SERVER,
-                                               ipc_path, load_in_low_bit, engine_alive))
+        engine_process = context.Process(
+            target=run_mp_engine,
+            args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
+                  engine_args.disable_log_stats,
+                  engine_args.disable_log_requests, load_in_low_bit, engine_alive))
         engine_process.start()
         engine_pid = engine_process.pid
         assert engine_pid is not None, "Engine process failed to start."
@@ -205,8 +257,7 @@ async def build_async_engine_client_from_engine_args(
             atexit.register(_cleanup_ipc_path)
 
         # Build RPCClient, which conforms to EngineClient Protocol.
-        engine_config = engine_args.create_engine_config()
-        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
+        build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
                                engine_pid)
         mq_engine_client = await asyncio.get_running_loop().run_in_executor(
             None, build_client)
@@ -244,6 +295,43 @@ async def build_async_engine_client_from_engine_args(
             multiprocess.mark_process_dead(engine_process.pid)
 
 
+async def validate_json_request(raw_request: Request):
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise HTTPException(
+            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            detail="Unsupported Media Type: Only 'application/json' is allowed"
+        )
+
+
+save_dict = {}
+import os
+flag = os.getenv("VLLM_LOG_OUTPUT", None)
+async def stream_generator(generator, request, request_id):
+    async for chunk in generator:
+        if request_id not in save_dict:
+            save_dict[request_id] = ""
+        import json
+        try:
+            data = chunk.strip()
+            if data.startswith('data: '):
+                data = data[len('data: '):]
+            else:
+                yield chunk
+            json_data = json.loads(data)
+            if 'choices' in json_data and len(json_data['choices']) > 0:
+                choice = json_data['choices'][0]
+                if 'delta' in choice:
+                    save_dict[request_id] += choice["delta"]["content"]
+                elif 'text' in choice:
+                    save_dict[request_id] += choice["text"]
+        except json.JSONDecodeError:
+            print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
+            pass  # Done
+        yield chunk
+
+
 router = APIRouter()
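The new validate_json_request helper is wired into the routes below as a FastAPI dependency, so endpoints such as /tokenize, /detokenize and /v1/chat/completions reject non-JSON bodies with 415 before the handler runs. A self-contained sketch of the same idea (a standalone app, not the vLLM router):

# Hedged sketch: rejecting non-JSON content types through a shared FastAPI dependency.
from http import HTTPStatus

from fastapi import Depends, FastAPI, HTTPException, Request

app = FastAPI()


async def validate_json_request(raw_request: Request):
    content_type = raw_request.headers.get("content-type", "").lower()
    media_type = content_type.split(";", maxsplit=1)[0]
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed")


@app.post("/echo", dependencies=[Depends(validate_json_request)])
async def echo(payload: dict):
    # Only reached when the request body was declared as application/json.
    return payload
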
@@ -254,6 +342,7 @@ def mount_metrics(app: FastAPI):
     # See https://prometheus.github.io/client_python/multiprocess/
     from prometheus_client import (CollectorRegistry, make_asgi_app,
                                    multiprocess)
+    from prometheus_fastapi_instrumentator import Instrumentator
 
     prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
     if prometheus_multiproc_dir_path is not None:
@@ -261,6 +350,16 @@ def mount_metrics(app: FastAPI):
                      prometheus_multiproc_dir_path)
         registry = CollectorRegistry()
         multiprocess.MultiProcessCollector(registry)
+        Instrumentator(
+            excluded_handlers=[
+                "/metrics",
+                "/health",
+                "/load",
+                "/ping",
+                "/version",
+            ],
+            registry=registry,
+        ).add().instrument(app).expose(app)
 
     # Add prometheus asgi middleware to route /metrics requests
     metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
@@ -298,7 +397,11 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
     return request.app.state.openai_serving_embedding
 
 
-def score(request: Request) -> Optional[OpenAIServingScores]:
+def score(request: Request) -> Optional[ServingScores]:
+    return request.app.state.openai_serving_scores
+
+
+def rerank(request: Request) -> Optional[ServingScores]:
     return request.app.state.openai_serving_scores
 
 
@@ -306,6 +409,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
     return request.app.state.openai_serving_tokenization
 
 
+def transcription(request: Request) -> OpenAIServingTranscription:
+    return request.app.state.openai_serving_transcription
+
+
 def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
 
@@ -317,7 +424,31 @@ async def health(raw_request: Request) -> Response:
     return Response(status_code=200)
 
 
-@router.post("/tokenize")
+@router.get("/load")
+async def get_server_load_metrics(request: Request):
+    # This endpoint returns the current server load metrics.
+    # It tracks requests utilizing the GPU from the following routes:
+    # - /v1/chat/completions
+    # - /v1/completions
+    # - /v1/audio/transcriptions
+    # - /v1/embeddings
+    # - /pooling
+    # - /score
+    # - /v1/score
+    # - /rerank
+    # - /v1/rerank
+    # - /v2/rerank
+    return JSONResponse(
+        content={'server_load': request.app.state.server_load_metrics})
+
+
+@router.api_route("/ping", methods=["GET", "POST"])
+async def ping(raw_request: Request) -> Response:
+    """Ping check. Endpoint required for SageMaker"""
+    return await health(raw_request)
+
+
+@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
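The new `/load` route simply returns `app.state.server_load_metrics`, the in-flight request counter that the `@load_aware_call` decorator (applied to the GPU-bound routes listed in the comment) keeps up to date; a later hunk in this diff initializes it to 0 and gates it behind `--enable-server-load-tracking`. A minimal client-side sketch for polling it; the host and port are assumptions, not part of the patch:

```python
import json
import time
import urllib.request

def poll_server_load(base_url: str = "http://localhost:8000",
                     interval_s: float = 1.0) -> None:
    """Print the server_load counter reported by the patched /load route."""
    while True:
        with urllib.request.urlopen(f"{base_url}/load") as resp:
            payload = json.load(resp)
        print("server_load =", payload["server_load"])
        time.sleep(interval_s)

# poll_server_load()  # uncomment to poll a locally running server
```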
@@ -332,7 +463,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
     assert_never(generator)
 
 
-@router.post("/detokenize")
+@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
@@ -361,35 +492,10 @@ async def show_version():
     return JSONResponse(content=ver)
 
 
-save_dict = {}
-import os
-flag = os.getenv("VLLM_LOG_OUTPUT", None)
-async def stream_generator(generator, request, request_id):
-    async for chunk in generator:
-        if request_id not in save_dict:
-            save_dict[request_id] = ""
-        import json
-        try:
-            data = chunk.strip()
-            if data.startswith('data: '):
-                data = data[len('data: '):]
-            else:
-                yield chunk
-            json_data = json.loads(data)
-            if 'choices' in json_data and len(json_data['choices']) > 0:
-                choice = json_data['choices'][0]
-                if 'delta' in choice:
-                    save_dict[request_id] += choice["delta"]["content"]
-                elif 'text' in choice:
-                    save_dict[request_id] += choice["text"]
-        except json.JSONDecodeError:
-            print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
-            pass # Done
-        yield chunk
-
-
-@router.post("/v1/chat/completions")
+@router.post("/v1/chat/completions",
+             dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -401,7 +507,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
     request_id = "chatcmpl-" \
         f"{handler._base_request_id(raw_request, request.request_id)}"
     print(f"First received request_id: {request_id}, request: {request}")
-
     generator = await handler.create_chat_completion(request, raw_request)
 
     if isinstance(generator, ErrorResponse):
@@ -418,8 +523,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
-@router.post("/v1/completions")
+@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -438,14 +544,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
         if flag is not None:
             print(f"Received request-id:{request_id}, request:{request}, Output:{generator.model_dump()}")
         return JSONResponse(content=generator.model_dump())
 
     if flag is not None:
         return StreamingResponse(content=stream_generator(generator, request, request_id), media_type="text/event-stream")
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
-@router.post("/v1/embeddings")
+@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -460,6 +567,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
             "use the Pooling API (`/pooling`) instead.")
 
         res = await fallback_handler.create_pooling(request, raw_request)
+
+        generator: Union[ErrorResponse, EmbeddingResponse]
         if isinstance(res, PoolingResponse):
             generator = EmbeddingResponse(
                 id=res.id,
@@ -488,8 +597,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     assert_never(generator)
 
 
-@router.post("/pooling")
+@router.post("/pooling", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
     handler = pooling(raw_request)
     if handler is None:
@@ -506,8 +616,9 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
     assert_never(generator)
 
 
-@router.post("/score")
+@router.post("/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -524,8 +635,9 @@ async def create_score(request: ScoreRequest, raw_request: Request):
     assert_never(generator)
 
 
-@router.post("/v1/score")
+@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
+@load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
@@ -534,6 +646,160 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
     return await create_score(request, raw_request)
 
 
+@router.post("/v1/audio/transcriptions")
+@with_cancellation
+@load_aware_call
+async def create_transcriptions(request: Annotated[TranscriptionRequest,
+                                                   Form()],
+                                raw_request: Request):
+    handler = transcription(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Transcriptions API")
+
+    audio_data = await request.file.read()
+    generator = await handler.create_transcription(audio_data, request,
+                                                   raw_request)
+
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+
+    elif isinstance(generator, TranscriptionResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
+@router.post("/rerank", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+@load_aware_call
+async def do_rerank(request: RerankRequest, raw_request: Request):
+    handler = rerank(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Rerank (Score) API")
+    generator = await handler.do_rerank(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+    elif isinstance(generator, RerankResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
+@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+async def do_rerank_v1(request: RerankRequest, raw_request: Request):
+    logger.warning_once(
+        "To indicate that the rerank API is not part of the standard OpenAI"
+        " API, we have located it at `/rerank`. Please update your client "
+        "accordingly. (Note: Conforms to JinaAI rerank API)")
+
+    return await do_rerank(request, raw_request)
+
+
+@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+async def do_rerank_v2(request: RerankRequest, raw_request: Request):
+    return await do_rerank(request, raw_request)
+
+
+TASK_HANDLERS: dict[str, dict[str, tuple]] = {
+    "generate": {
+        "messages": (ChatCompletionRequest, create_chat_completion),
+        "default": (CompletionRequest, create_completion),
+    },
+    "embed": {
+        "messages": (EmbeddingChatRequest, create_embedding),
+        "default": (EmbeddingCompletionRequest, create_embedding),
+    },
+    "score": {
+        "default": (RerankRequest, do_rerank)
+    },
+    "rerank": {
+        "default": (RerankRequest, do_rerank)
+    },
+    "reward": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+    "classify": {
+        "messages": (PoolingChatRequest, create_pooling),
+        "default": (PoolingCompletionRequest, create_pooling),
+    },
+}
+
+if envs.VLLM_SERVER_DEV_MODE:
+
+    @router.post("/reset_prefix_cache")
+    async def reset_prefix_cache(raw_request: Request):
+        """
+        Reset the prefix cache. Note that we currently do not check if the
+        prefix cache is successfully reset in the API server.
+        """
+        device = None
+        device_str = raw_request.query_params.get("device")
+        if device_str is not None:
+            device = Device[device_str.upper()]
+        logger.info("Resetting prefix cache with specific %s...", str(device))
+        await engine_client(raw_request).reset_prefix_cache(device)
+        return Response(status_code=200)
+
+    @router.post("/sleep")
+    async def sleep(raw_request: Request):
+        # get POST params
+        level = raw_request.query_params.get("level", "1")
+        await engine_client(raw_request).sleep(int(level))
+        # FIXME: in v0 with frontend multiprocessing, the sleep command
+        # is sent but does not finish yet when we return a response.
+        return Response(status_code=200)
+
+    @router.post("/wake_up")
+    async def wake_up(raw_request: Request):
+        tags = raw_request.query_params.getlist("tags")
+        if tags == []:
+            # set to None to wake up all tags if no tags are provided
+            tags = None
+        logger.info("wake up the engine with tags: %s", tags)
+        await engine_client(raw_request).wake_up(tags)
+        # FIXME: in v0 with frontend multiprocessing, the wake-up command
+        # is sent but does not finish yet when we return a response.
+        return Response(status_code=200)
+
+    @router.get("/is_sleeping")
+    async def is_sleeping(raw_request: Request):
+        logger.info("check whether the engine is sleeping")
+        is_sleeping = await engine_client(raw_request).is_sleeping()
+        return JSONResponse(content={"is_sleeping": is_sleeping})
+
+
+@router.post("/invocations", dependencies=[Depends(validate_json_request)])
+async def invocations(raw_request: Request):
+    """
+    For SageMaker, routes requests to other handlers based on model `task`.
+    """
+    body = await raw_request.json()
+    task = raw_request.app.state.task
+
+    if task not in TASK_HANDLERS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported task: '{task}' for '/invocations'. "
+            f"Expected one of {set(TASK_HANDLERS.keys())}")
+
+    handler_config = TASK_HANDLERS[task]
+    if "messages" in body:
+        request_model, handler = handler_config["messages"]
+    else:
+        request_model, handler = handler_config["default"]
+
+    # this is required since we lose the FastAPI automatic casting
+    request = request_model.model_validate(body)
+    return await handler(request, raw_request)
+
+
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning(
         "Torch Profiler is enabled in the API server. This should ONLY be "
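The `/invocations` route added above is the SageMaker-style entry point: it looks up `app.state.task` in `TASK_HANDLERS` and picks the `messages` variant when the JSON body contains a `messages` key. A simplified stand-in for that dispatch rule (handler names are plain strings instead of the real coroutines, and the bodies are invented for illustration):

```python
# Simplified model of the TASK_HANDLERS lookup used by /invocations.
TASK_HANDLERS = {
    "generate": {"messages": "create_chat_completion",
                 "default": "create_completion"},
    "embed": {"messages": "create_embedding",
              "default": "create_embedding"},
}

def pick_handler(task: str, body: dict) -> str:
    if task not in TASK_HANDLERS:
        raise ValueError(f"Unsupported task: {task!r}")
    handler_config = TASK_HANDLERS[task]
    key = "messages" if "messages" in body else "default"
    return handler_config[key]

print(pick_handler("generate", {"messages": [{"role": "user", "content": "hi"}]}))
# -> create_chat_completion
print(pick_handler("generate", {"prompt": "hi"}))
# -> create_completion
```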
@@ -556,32 +822,30 @@ if envs.VLLM_TORCH_PROFILER_DIR:
 
 if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
     logger.warning(
-        "Lora dynamic loading & unloading is enabled in the API server. "
+        "LoRA dynamic loading & unloading is enabled in the API server. "
         "This should ONLY be used for local development!")
 
-    @router.post("/v1/load_lora_adapter")
-    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+    @router.post("/v1/load_lora_adapter",
+                 dependencies=[Depends(validate_json_request)])
+    async def load_lora_adapter(request: LoadLoRAAdapterRequest,
                                 raw_request: Request):
-        for route in [chat, completion, embedding]:
-            handler = route(raw_request)
-            if handler is not None:
-                response = await handler.load_lora_adapter(request)
-                if isinstance(response, ErrorResponse):
-                    return JSONResponse(content=response.model_dump(),
-                                        status_code=response.code)
+        handler = models(raw_request)
+        response = await handler.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
 
         return Response(status_code=200, content=response)
 
-    @router.post("/v1/unload_lora_adapter")
-    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+    @router.post("/v1/unload_lora_adapter",
+                 dependencies=[Depends(validate_json_request)])
+    async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
                                   raw_request: Request):
-        for route in [chat, completion, embedding]:
-            handler = route(raw_request)
-            if handler is not None:
-                response = await handler.unload_lora_adapter(request)
-                if isinstance(response, ErrorResponse):
-                    return JSONResponse(content=response.model_dump(),
-                                        status_code=response.code)
+        handler = models(raw_request)
+        response = await handler.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
 
         return Response(status_code=200, content=response)
 
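With the rewrite above, both LoRA routes go through the single `models(raw_request)` handler instead of probing the chat/completion/embedding handlers in turn. A hedged sketch of a client call against the patched route; the adapter name and path are placeholders, the field names follow vLLM's `LoadLoRAAdapterRequest` (`lora_name`, `lora_path`), and the server must be started with `VLLM_ALLOW_RUNTIME_LORA_UPDATING=1` for the route to exist:

```python
import json
import urllib.request

def load_lora_adapter(base_url: str, lora_name: str, lora_path: str) -> int:
    """POST to /v1/load_lora_adapter; returns the HTTP status code."""
    body = json.dumps({"lora_name": lora_name, "lora_path": lora_path}).encode()
    req = urllib.request.Request(
        f"{base_url}/v1/load_lora_adapter",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return resp.status

# Placeholder values, for illustration only:
# load_lora_adapter("http://localhost:8000", "my-adapter", "/models/my-adapter")
```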
@@ -615,7 +879,8 @@ def build_app(args: Namespace) -> FastAPI:
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)
 
-    if token := envs.VLLM_API_KEY or args.api_key:
+    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
+    if token := args.api_key or envs.VLLM_API_KEY:
 
         @app.middleware("http")
         async def authentication(request: Request, call_next):
@@ -644,11 +909,26 @@ def build_app(args: Namespace) -> FastAPI:
             response.headers["X-Request-Id"] = request_id
             return response
 
+    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
+        logger.warning("CAUTION: Enabling log response in the API Server. "
+                       "This can include sensitive information and should be "
+                       "avoided in production.")
+
+        @app.middleware("http")
+        async def log_response(request: Request, call_next):
+            response = await call_next(request)
+            response_body = [
+                section async for section in response.body_iterator
+            ]
+            response.body_iterator = iterate_in_threadpool(iter(response_body))
+            logger.info("response_body={%s}", response_body[0].decode())
+            return response
+
     for middleware in args.middleware:
         module_path, object_name = middleware.rsplit(".", 1)
         imported = getattr(importlib.import_module(module_path), object_name)
         if inspect.isclass(imported):
-            app.add_middleware(imported)
+            app.add_middleware(imported)  # type: ignore[arg-type]
         elif inspect.iscoroutinefunction(imported):
             app.middleware("http")(imported)
         else:
@@ -658,7 +938,7 @@ def build_app(args: Namespace) -> FastAPI:
     return app
 
 
-def init_app_state(
+async def init_app_state(
     engine_client: EngineClient,
     model_config: ModelConfig,
     state: State,
@@ -683,15 +963,36 @@ def init_app_state(
     state.log_stats = not args.disable_log_stats
 
     resolved_chat_template = load_chat_template(args.chat_template)
-    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+    if resolved_chat_template is not None:
+        # Get the tokenizer to check official template
+        tokenizer = await engine_client.get_tokenizer()
+
+        if isinstance(tokenizer, MistralTokenizer):
+            # The warning is logged in resolve_mistral_chat_template.
+            resolved_chat_template = resolve_mistral_chat_template(
+                chat_template=resolved_chat_template)
+        else:
+            hf_chat_template = resolve_hf_chat_template(
+                tokenizer,
+                chat_template=None,
+                tools=None,
+                trust_remote_code=model_config.trust_remote_code)
+
+            if hf_chat_template != resolved_chat_template:
+                logger.warning(
+                    "Using supplied chat template: %s\n"
+                    "It is different from official chat template '%s'. "
+                    "This discrepancy may lead to performance degradation.",
+                    resolved_chat_template, args.model)
+
     state.openai_serving_models = OpenAIServingModels(
+        engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
     )
-    # TODO: The chat template is now broken for lora adapters :(
+    await state.openai_serving_models.init_static_loras()
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -703,6 +1004,8 @@ def init_app_state(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
+        enable_reasoning=args.enable_reasoning,
+        reasoning_parser=args.reasoning_parser,
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.runner_type == "generate" else None
     state.openai_serving_completion = OpenAIServingCompletion(
@@ -728,7 +1031,13 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embed" else None
-    state.openai_serving_scores = OpenAIServingScores(
+    state.openai_serving_scores = ServingScores(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger) if model_config.task in (
+            "score", "embed", "pooling") else None
+    state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
         state.openai_serving_models,
@@ -742,92 +1051,26 @@ def init_app_state(
         chat_template=resolved_chat_template,
         chat_template_content_format=args.chat_template_content_format,
     )
+    state.openai_serving_transcription = OpenAIServingTranscription(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger,
+    ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task
-    # if args.served_model_name is not None:
-    #     served_model_names = args.served_model_name
-    # else:
-    #     served_model_names = [args.model]
-
-    # if args.disable_log_requests:
-    #     request_logger = None
-    # else:
-    #     request_logger = RequestLogger(max_log_len=args.max_log_len)
-
-    # base_model_paths = [
-    #     BaseModelPath(name=name, model_path=args.model)
-    #     for name in served_model_names
-    # ]
-
-    # state.engine_client = engine_client
-    # state.log_stats = not args.disable_log_stats
-
-    # resolved_chat_template = load_chat_template(args.chat_template)
-    # logger.info("Using supplied chat template:\n%s", resolved_chat_template)
-
-    # state.openai_serving_chat = OpenAIServingChat(
-    #     engine_client,
-    #     model_config,
-    #     base_model_paths,
-    #     args.response_role,
-    #     lora_modules=args.lora_modules,
-    #     prompt_adapters=args.prompt_adapters,
-    #     request_logger=request_logger,
-    #     chat_template=resolved_chat_template,
-    #     chat_template_content_format=args.chat_template_content_format,
-    #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    #     enable_auto_tools=args.enable_auto_tool_choice,
-    #     tool_parser=args.tool_call_parser,
-    #     enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-    # ) if model_config.runner_type == "generate" else None
-    # state.openai_serving_completion = OpenAIServingCompletion(
-    #     engine_client,
-    #     model_config,
-    #     base_model_paths,
-    #     lora_modules=args.lora_modules,
-    #     prompt_adapters=args.prompt_adapters,
-    #     request_logger=request_logger,
-    #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-    # ) if model_config.runner_type == "generate" else None
-    # state.openai_serving_pooling = OpenAIServingPooling(
-    #     engine_client,
-    #     model_config,
-    #     base_model_paths,
-    #     request_logger=request_logger,
-    #     chat_template=resolved_chat_template,
-    #     chat_template_content_format=args.chat_template_content_format,
-    # ) if model_config.runner_type == "pooling" else None
-    # state.openai_serving_embedding = OpenAIServingEmbedding(
-    #     engine_client,
-    #     model_config,
-    #     base_model_paths,
-    #     request_logger=request_logger,
-    #     chat_template=resolved_chat_template,
-    #     chat_template_content_format=args.chat_template_content_format,
-    # ) if model_config.task == "embed" else None
-    # state.openai_serving_scores = OpenAIServingScores(
-    #     engine_client,
-    #     model_config,
-    #     base_model_paths,
-    #     request_logger=request_logger
-    # ) if model_config.task == "score" else None
-    # state.openai_serving_tokenization = OpenAIServingTokenization(
-    #     engine_client,
-    #     model_config,
-    #     base_model_paths,
-    #     lora_modules=args.lora_modules,
-    #     request_logger=request_logger,
-    #     chat_template=resolved_chat_template,
-    #     chat_template_content_format=args.chat_template_content_format,
-    # )
+    state.enable_server_load_tracking = args.enable_server_load_tracking
+    state.server_load_metrics = 0
 
 
-def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
+def create_server_socket(addr: tuple[str, int]) -> socket.socket:
     family = socket.AF_INET
     if is_valid_ipv6_address(addr[0]):
         family = socket.AF_INET6
 
     sock = socket.socket(family=family, type=socket.SOCK_STREAM)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
     sock.bind(addr)
 
     return sock
@@ -840,11 +1083,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 
-    valide_tool_parses = ToolParserManager.tool_parsers.keys()
+    valid_tool_parses = ToolParserManager.tool_parsers.keys()
     if args.enable_auto_tool_choice \
-            and args.tool_call_parser not in valide_tool_parses:
+            and args.tool_call_parser not in valid_tool_parses:
         raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
-                       f"(chose from {{ {','.join(valide_tool_parses)} }})")
+                       f"(chose from {{ {','.join(valid_tool_parses)} }})")
+
+    valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
+    if args.enable_reasoning \
+            and args.reasoning_parser not in valid_reasoning_parses:
+        raise KeyError(
+            f"invalid reasoning parser: {args.reasoning_parser} "
+            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")
 
     # workaround to make sure that we bind the port before the engine is set up.
     # This avoids race conditions with ray.
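The added check mirrors the tool-parser validation one block up: when `--enable-reasoning` is set, the chosen `--reasoning-parser` must be one of the registered parsers, otherwise startup fails early. A simplified stand-in for that guard; the registry contents here are placeholders:

```python
# Placeholder registry; the real one is ReasoningParserManager.reasoning_parsers.
valid_reasoning_parsers = {"deepseek_r1"}

def check_reasoning_parser(enable_reasoning: bool, reasoning_parser: str) -> None:
    if enable_reasoning and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parsers)} }})")

check_reasoning_parser(True, "deepseek_r1")   # passes silently
# check_reasoning_parser(True, "unknown")     # would raise KeyError
```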
@@ -866,13 +1116,28 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     app = build_app(args)
 
     model_config = await engine_client.get_model_config()
-    init_app_state(engine_client, model_config, app.state, args)
+    await init_app_state(engine_client, model_config, app.state, args)
+
+    def _listen_addr(a: str) -> str:
+        if is_valid_ipv6_address(a):
+            return '[' + a + ']'
+        return a or "0.0.0.0"
+
+    is_ssl = args.ssl_keyfile and args.ssl_certfile
+    logger.info("Starting vLLM API server on http%s://%s:%d",
+                "s" if is_ssl else "", _listen_addr(sock_addr[0]),
+                sock_addr[1])
 
     shutdown_task = await serve_http(
         app,
+        sock=sock,
+        enable_ssl_refresh=args.enable_ssl_refresh,
         host=args.host,
         port=args.port,
         log_level=args.uvicorn_log_level,
+        # NOTE: When the 'disable_uvicorn_access_log' value is True,
+        # no access log will be output.
+        access_log=not args.disable_uvicorn_access_log,
         timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
         ssl_keyfile=args.ssl_keyfile,
         ssl_certfile=args.ssl_certfile,
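The `_listen_addr` helper added above only affects the startup log line: IPv6 hosts are wrapped in brackets and an empty host falls back to `0.0.0.0`. A self-contained illustration, with a local stand-in for the `is_valid_ipv6_address` helper that the real file imports from vLLM's utilities:

```python
import ipaddress

def is_valid_ipv6_address(addr: str) -> bool:
    # Local stand-in for the helper imported from vllm.utils in the real file.
    try:
        ipaddress.IPv6Address(addr)
        return True
    except ValueError:
        return False

def _listen_addr(a: str) -> str:
    if is_valid_ipv6_address(a):
        return '[' + a + ']'
    return a or "0.0.0.0"

print(_listen_addr("::1"))        # [::1]
print(_listen_addr(""))           # 0.0.0.0
print(_listen_addr("127.0.0.1"))  # 127.0.0.1
```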
@@ -882,16 +1147,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     )
 
     # NB: Await server shutdown only after the backend context is exited
-    await shutdown_task
-    sock.close()
+    try:
+        await shutdown_task
+    finally:
+        sock.close()
 
 
 if __name__ == "__main__":
     # NOTE(simon):
-    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
+    # entrypoints.
     logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
                    "instead of `vllm.entrypoints.openai.api_server` to start the API server")
+    cli_env_setup()
     parser = FlexibleArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
(The hunks below belong to a second changed file in this commit; its path is not shown in this view.)

@@ -48,7 +48,7 @@ def _sample_get_logits(
     logits = lm_head(hidden_states)
     if embedding_bias is not None:
         logits += embedding_bias
-    if self.use_gather:
+    if self.use_all_gather:
         logits = tensor_model_parallel_gather(logits)
     else:
         logits = tensor_model_parallel_all_gather(logits)
@@ -63,6 +63,8 @@ def _model_sample_convert():
 
 
 def _ipex_llm_convert(load_in_low_bit):
+    # import pdb
+    # pdb.set_trace()
     from vllm.worker.xpu_model_runner import XPUModelRunner
     from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
     from ipex_llm.vllm.xpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper
@@ -99,7 +101,8 @@ def get_load_function(low_bit):
                 "codegeex4-all" in self.vllm_config.model_config.model.lower() or
                 "chatglm" in self.vllm_config.model_config.model.lower()) and \
                 "gptq" not in self.model_config.model.lower() and \
-                "awq" not in self.model_config.model.lower():
+                "awq" not in self.model_config.model.lower() and \
+                "qwen3" not in self.model_config.model.lower():
             self.model.apply(padding_mlp)
             from ipex_llm import optimize_model
             not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)