vLLM: update vLLM XPU to 0.8.3 version (#13118)
This commit is contained in:
parent f66eee1d1d
commit 51b41faad7

11 changed files with 4608 additions and 28217 deletions
@@ -1,7 +1,9 @@
From d345631f78a2f33ff1ddd7d9908b288eb0afaf46 Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Fri, 24 May 2024 09:47:26 +0800
Subject: [PATCH 1/3] allreduce optimization with LL256 for Arc770 dGPU
From dfe1851b59df6859829b447353307b7c916ccee0 Mon Sep 17 00:00:00 2001
From: junhansh <junhan.shi@intel.com>
Date: Mon, 28 Apr 2025 23:33:11 +0800
Subject: [PATCH] oneccl for Arc770 V2025.0.0.6.7

allreduce optimization with LL256 for Arc770 dGPU

To enable this feature, please set env var:
    export CCL_DG2_ALLREDUCE=1
@@ -12,6 +14,15 @@ Build:
    3. cmake .. -GNinja -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DCMAKE_CXX_FLAGS="-fsycl" -DCOMPUTE_BACKEND=dpcpp  -DCMAKE_BUILD_TYPE=MinSizeRel
    4. ninja
    5. ls -al src/libccl*

Changes:
optimize req_workgroup calculate

Revert "optimize req_workgroup calculate" for hang issue

This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.

fix_fdset_buffer_overflow_issue
---
 src/CMakeLists.txt               |   2 +
 src/coll/coll.cpp                |  30 +-
@@ -20,9 +31,9 @@ Build:
 src/common/env/env.cpp           |   1 +
 src/common/env/env.hpp           |   1 +
 src/common/env/vars.hpp          |   1 +
 src/dg2/dg2_allreduce.cpp        | 642 +++++++++++++++++++++++++++++++
 src/dg2/dg2_allreduce.cpp        | 640 +++++++++++++++++++++++++++++++
 src/dg2/dg2_allreduce.hpp        |  13 +
 9 files changed, 693 insertions(+), 3 deletions(-)
 9 files changed, 691 insertions(+), 3 deletions(-)
 create mode 100644 src/dg2/dg2_allreduce.cpp
 create mode 100644 src/dg2/dg2_allreduce.hpp

@@ -163,10 +174,10 @@ index 73dcf77..84ab518 100644
 constexpr const char* CCL_MIN_CHUNK_SIZE = "CCL_MIN_CHUNK_SIZE";
diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
new file mode 100644
index 0000000..15ace74
index 0000000..73e114b
--- /dev/null
+++ b/src/dg2/dg2_allreduce.cpp
@@ -0,0 +1,642 @@
@@ -0,0 +1,640 @@
+#include <fcntl.h>
+#include <unistd.h>
+#include <sys/un.h>
@@ -178,7 +189,7 @@ index 0000000..15ace74
+#include <drm/drm.h>
+
+#include <mpi.h>
+
+#include <poll.h>
+#include <vector>
+#include <sstream>
+#include <iostream>
@@ -315,7 +326,6 @@ index 0000000..15ace74
+
+static void *thread_func(void *arg)
+{
+    fd_set fds;
+    int count = 0;
+    char sock_path[64];
+    int peer_buf_fd = 0;
@@ -323,6 +333,10 @@ index 0000000..15ace74
+
+    snprintf(sock_path, sizeof(sock_path), "%s-%d_%d", SOCK_PATH, rank, 0xa770);
+    int srv_fd = srv_sock(sock_path);
+    if (srv_fd < 0) {
+         perror("srv_sock failed");
+         return nullptr;
+    }
+
+    //std::cout << "-----> srv_fd of " << sock_path << " : " << srv_fd << "\n";
+
@@ -331,35 +345,30 @@ index 0000000..15ace74
+    ze_context_handle_t ze_context = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_context);
+    ze_device_handle_t  ze_device = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(sycl_device);
+
+    FD_ZERO(&fds);
+    FD_SET(srv_fd, &fds);
+    struct pollfd pfd = {
+            .fd = srv_fd,
+            .events = POLL_IN,
+            .revents = 0
+    };
+    while (++count < world_size) {
+        int ret = select(srv_fd + 1, &fds, NULL, NULL, NULL);
+        switch (ret) {
+        case 1:
+            {
+                int peer_rank;
+                void *peer_buf;
+        int ret = poll(&pfd, 1, -1);
+        if (ret <= 0) {
+           std::cerr << "poll failed: " << strerror(errno) << "\n";
+           break;
+        }
+
+                int conn_fd = accept(srv_fd, NULL, 0);
+                ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
+        if (pfd.revents & POLL_IN) {
+           int peer_rank;
+           void *peer_buf = nullptr;
+
+                ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
+                zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED /* cached allocation */, &peer_buf);
+           int conn_fd = accept(srv_fd, NULL, 0);
+           ccl::utils::recvmsg_fd(conn_fd, &peer_buf_fd, &peer_rank, sizeof(peer_rank));
+           ze_ipc_mem_handle_t ipc_handle_peer_buf = get_handle_from_fd(peer_buf_fd);
+           zeMemOpenIpcHandle(ze_context, ze_device, ipc_handle_peer_buf, ZE_IPC_MEMORY_FLAG_BIAS_CACHED, &peer_buf);
+
+                peer_bufs[peer_rank] = peer_buf;
+                //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
+
+                if (conn_fd > 0) close(conn_fd);
+
+                break;
+            }
+        case 0:
+        case -1:
+            std::cout << "srv_fd select() failed" << "\n";
+            break;
+        default:
+            break;
+           peer_bufs[peer_rank] = peer_buf;
+           //printf("<------------- rank: %d, peer_bufs[%d]: %p\n", world_rank, peer_rank, peer_bufs[peer_rank]);
+           if (conn_fd > 0) close(conn_fd);
+        }
+    }
+
@@ -831,105 +840,3 @@ index 0000000..0506445
-- 
2.34.1


From 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Fri, 7 Mar 2025 15:15:35 +0800
Subject: [PATCH 2/3] optimize req_workgroup calculate

---
 src/dg2/dg2_allreduce.cpp | 25 ++-----------------------
 1 file changed, 2 insertions(+), 23 deletions(-)

diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
index 15ace74..83270ae 100644
--- a/src/dg2/dg2_allreduce.cpp
+++ b/src/dg2/dg2_allreduce.cpp
@@ -527,30 +527,9 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
                 auto chunk_sz = req_workitems * LS_SZ;         /* LS_SZ bytes per work-item */
                 auto chunk_with_pattern = sg_sz * LS_SZ;       /* aligned to 256B */

-                /* items will be assigned to each rank */
-                auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
-                auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
-                auto req_subgroups = 0;
-
-                if (req_workgroups >= g_sz/l_sz) {
-                    req_workgroups = g_sz/l_sz;
-                } else {
-                    if (group_id == (req_workgroups - 1)) {
-                        req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
-
-                        /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
-                        /* Note:  req_subgroups % (l_sz/sg_sz) might be 0 */
-                        if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
-                            if ((per_rank_items % (sg_sz - 1)) != 0) {
-                                /* FIXME: */
-                                req_workitems = per_rank_items % (sg_sz - 1);
-                                chunk_sz = req_workitems * LS_SZ;    /* LS_SZ bytes per work-item */
-                            }
-                        }
-                    }
-                }
+		auto work_left = unreduced - sg_id * local_world_size * chunk_sz;

-                if (group_id < req_workgroups) {
+                if (work_left > 0) {
                     // step 1: push data to next GPU
                     {
                         offset = base + local_world_rank * chunk_sz;
-- 
2.34.1


From 1c58cc9ede5ca38138a270f9e5ff59bca74f51d4 Mon Sep 17 00:00:00 2001
From: Huajun Li <huajun.li@.com>
Date: Wed, 12 Mar 2025 13:29:27 +0800
Subject: [PATCH 3/3] Revert "optimize req_workgroup calculate" for hang issue

This reverts commit 20bfd0e0a37f93dfb8bb9c092cd5a0b35e868bfa.
---
 src/dg2/dg2_allreduce.cpp | 25 +++++++++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/src/dg2/dg2_allreduce.cpp b/src/dg2/dg2_allreduce.cpp
index 83270ae..15ace74 100644
--- a/src/dg2/dg2_allreduce.cpp
+++ b/src/dg2/dg2_allreduce.cpp
@@ -527,9 +527,30 @@ ccl::event dg2_ll256_allreduce(const void *src, void *dst, size_t count,
                 auto chunk_sz = req_workitems * LS_SZ;         /* LS_SZ bytes per work-item */
                 auto chunk_with_pattern = sg_sz * LS_SZ;       /* aligned to 256B */

-		auto work_left = unreduced - sg_id * local_world_size * chunk_sz;
+                /* items will be assigned to each rank */
+                auto per_rank_items = (unreduced + (local_world_size * LS_SZ - 1)) / (local_world_size * LS_SZ);
+                auto req_workgroups = (per_rank_items + (workgroup_available_items - 1)) / workgroup_available_items;
+                auto req_subgroups = 0;
+
+                if (req_workgroups >= g_sz/l_sz) {
+                    req_workgroups = g_sz/l_sz;
+                } else {
+                    if (group_id == (req_workgroups - 1)) {
+                        req_subgroups = (per_rank_items + (sg_sz - 1)) / (sg_sz - 1);
+
+                        /* (req_subgroups % (l_sz/sg_sz) - 1) equals to the final subgroup id in a workgroup */
+                        /* Note:  req_subgroups % (l_sz/sg_sz) might be 0 */
+                        if (((req_subgroups % (l_sz/sg_sz)) == 0) || (sg_id == (req_subgroups % (l_sz/sg_sz) - 1))) {
+                            if ((per_rank_items % (sg_sz - 1)) != 0) {
+                                /* FIXME: */
+                                req_workitems = per_rank_items % (sg_sz - 1);
+                                chunk_sz = req_workitems * LS_SZ;    /* LS_SZ bytes per work-item */
+                            }
+                        }
+                    }
+                }

-                if (work_left > 0) {
+                if (group_id < req_workgroups) {
                     // step 1: push data to next GPU
                     {
                         offset = base + local_world_rank * chunk_sz;
-- 
2.34.1
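Reviewer note: the oneCCL patch above gates the LL256 allreduce path for Arc770 behind the CCL_DG2_ALLREDUCE environment variable, which has to be set before oneCCL is initialized. A minimal sketch of doing that from Python is shown below; the import name oneccl_bindings_for_pytorch and the use of RANK/WORLD_SIZE from the launcher are assumptions for illustration, not part of this commit.

# Hypothetical launcher sketch: set the oneCCL flags before torch / torch-ccl are
# imported, so the DG2 (Arc770) LL256 allreduce path is selected at init time.
import os

# Assumption: these mirror the env vars used in the shell scripts in this commit.
os.environ.setdefault("CCL_DG2_ALLREDUCE", "1")   # enable LL256 allreduce on Arc770
os.environ.setdefault("CCL_SAME_STREAM", "1")
os.environ.setdefault("CCL_BLOCKING_WAIT", "0")

import torch
import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401  (torch-ccl wheel built above; module name assumed)

def ccl_allreduce_smoke_test() -> None:
    # RANK / WORLD_SIZE / MASTER_ADDR are expected to come from the launcher (mpirun, torchrun, ...).
    dist.init_process_group(backend="ccl")
    t = torch.ones(1024, device=f"xpu:{dist.get_rank()}")
    dist.all_reduce(t)  # routed through oneCCL; LL256 path when CCL_DG2_ALLREDUCE=1
    print(dist.get_rank(), t[0].item())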
@@ -54,19 +54,20 @@ RUN set -eux && \
    #
    # Install Intel PyTorch extension for LLM inference
    pip install --pre --upgrade ipex-llm[xpu_2.6] --extra-index-url https://download.pytorch.org/whl/xpu && \
    pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
    #
    # Build torch-ccl
    mkdir -p /build && \
    cd /build && \
    git clone https://github.com/intel/torch-ccl.git && \
    cd torch-ccl && \
    git checkout ccl_torch2.5.0+xpu && \
    git checkout ccl_torch2.6.0+xpu && \
    git submodule sync && \
    git submodule update --init --recursive && \
    # This patch will enable build torch-ccl with pytorch 2.6 environment
    git apply /tmp/ccl_torch.patch && \
    # git apply /tmp/ccl_torch.patch && \
    USE_SYSTEM_ONECCL=ON COMPUTE_BACKEND=dpcpp python setup.py bdist_wheel && \
    # File path: /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl
    # File path: /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl
    # Build oneCCL
    pip install ninja && \
    cd /build/ && \
@@ -85,7 +86,7 @@ RUN set -eux && \
FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04

# Copy the built torch-ccl package from the build stage
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
COPY --from=build /build/torch-ccl/dist/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl /opt/
COPY --from=build /llm/ /llm/
COPY --from=build /build/oneCCL/build/src/libccl.so.1.0 /opt/intel/1ccl-wks/lib/
COPY --from=build /build/oneCCL/build/src/libccl.so.1 /opt/intel/1ccl-wks/lib/
@@ -144,9 +145,10 @@ RUN set -eux && \
    # Install vllm dependencies
    pip install --upgrade fastapi && \
    pip install --upgrade "uvicorn[standard]" && \
    pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
    #
    # Install torch-ccl
    pip install /opt/oneccl_bind_pt-2.5.0+xpu-cp311-cp311-linux_x86_64.whl && \
    pip install /opt/oneccl_bind_pt-2.6.0+xpu-cp311-cp311-linux_x86_64.whl && \
    #
    apt-get update && \
    apt-get install -y --no-install-recommends libfabric-dev wrk libaio-dev numactl && \
@@ -168,21 +170,19 @@ RUN set -eux && \
    mkdir -p /llm && \
    cd /llm && \
    rm -rf /tmp/neo && \
    # Install intel_extension_for_pytorch
    pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \
    pip uninstall -y oneccl oneccl-devel && \
    pip install intel-opencl-rt==2025.0.2 intel-openmp==2025.0.2 && \
    #
    # Install vllm
    git clone -b v0.6.6.post1 https://github.com/vllm-project/vllm /llm/vllm && \
    git clone -b v0.8.3 https://github.com/vllm-project/vllm /llm/vllm && \
    cd /llm/vllm && \
    git apply /llm/vllm_for_multi_arc.patch && \
    pip install setuptools-scm && \
    pip install setuptools-scm==8.2.0 setuptools==78.1.0 && \
    pip install --upgrade cmake && \
    VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm && \
    pip install -v -r requirements/xpu.txt && \
    VLLM_TARGET_DEVICE=xpu python setup.py install && \
    pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ && \
    pip uninstall -y oneccl oneccl-devel && \
    rm -rf /llm/vllm_for_multi_arc.patch && \
    pip install mpi4py fastapi uvicorn openai && \
    pip install ray
    pip install ray numba


WORKDIR /llm/
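Reviewer note: the Dockerfile now pins intel-extension-for-pytorch 2.6.10+xpu, the rebuilt oneccl_bind_pt 2.6.0+xpu wheel, and vLLM built from the v0.8.3 tag plus the multi-arc patch. A small, hedged sanity check along these lines can catch a stale wheel in the final image; the distribution names are taken from the pip commands above and may need adjusting.

# Sketch: verify the image ends up with the package versions the Dockerfile expects.
from importlib.metadata import version, PackageNotFoundError

EXPECTED = {
    "intel-extension-for-pytorch": "2.6.10+xpu",
    "oneccl-bind-pt": "2.6.0+xpu",   # wheel oneccl_bind_pt; distribution name assumed
    "vllm": "0.8.3",                 # built from the v0.8.3 tag with the multi-arc patch applied
}

for name, expected in EXPECTED.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: MISSING (expected {expected})")
        continue
    status = "ok" if installed.startswith(expected.split("+")[0]) else "MISMATCH"
    print(f"{name}: {installed} (expected {expected}) -> {status}")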
@@ -32,6 +32,9 @@ export TORCH_LLM_ALLREDUCE=0
export CCL_SAME_STREAM=1
export CCL_BLOCKING_WAIT=0

export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=$LOAD_IN_LOW_BIT

source /opt/intel/1ccl-wks/setvars.sh

python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \

File diff suppressed because one or more lines are too long

@@ -782,6 +782,9 @@ export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0

export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=fp8

source /opt/intel/1ccl-wks/setvars.sh

python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
@@ -793,7 +796,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
 --device xpu \
 --dtype float16 \
 --enforce-eager \
 --load-in-low-bit fp8 \
 --load-in-low-bit $IPEX_LLM_LOWBIT \
 --max-model-len 2048 \
 --max-num-batched-tokens 4000 \
 --api-key <your-api-key> \
@@ -50,9 +50,14 @@ pip install --pre --upgrade "ipex-llm[xpu_2.6]" --extra-index-url https://pytorc
pip install setuptools-scm
pip install --upgrade cmake
# cd to your workdir
git clone -b 0.6.6 https://github.com/analytics-zoo/vllm.git
git clone -b 0.8.3 https://github.com/analytics-zoo/vllm.git
cd vllm
VLLM_TARGET_DEVICE=xpu pip install --no-build-isolation -v /llm/vllm
pip install setuptools-scm==8.2.0 setuptools==78.1.0
pip install --upgrade cmake
pip install -v -r requirements/xpu.txt
VLLM_TARGET_DEVICE=xpu python setup.py install
pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
pip uninstall -y oneccl oneccl-devel
# For Qwen model support
pip install transformers_stream_generator einops tiktoken
pip install ray
@@ -93,6 +98,8 @@ For vLLM, you can start the service using the following command:
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
export VLLM_RPC_TIMEOUT=100000
export VLLM_USE_V1=0
export IPEX_LLM_LOWBIT=fp8

 # You may need to adjust the value of
 # --max-model-len, --max-num-batched-tokens, --max-num-seqs
@@ -107,7 +114,7 @@ python -m ipex_llm.vllm.xpu.entrypoints.openai.api_server \
 --device xpu \
 --dtype float16 \
 --enforce-eager \
 --load-in-low-bit sym_int4 \
 --load-in-low-bit $IPEX_LLM_LOWBIT \
 --max-model-len 4096 \
 --max-num-batched-tokens 10240 \
 --max-num-seqs 12 \
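Reviewer note: once the api_server shown above is running, it exposes the OpenAI-compatible API. A minimal client call is sketched below; the host, port 8000, and model name are placeholders, and the api key must match the --api-key passed to the server.

# Sketch: query the ipex_llm.vllm.xpu OpenAI-compatible server started above.
from openai import OpenAI  # the image installs the openai package

client = OpenAI(
    base_url="http://localhost:8000/v1",   # assumed default vLLM port; adjust as needed
    api_key="your-api-key",                # must match --api-key on the server
)

resp = client.chat.completions.create(
    model="YOUR_MODEL_NAME",               # the served_model_name used at launch
    messages=[{"role": "user", "content": "Hello from an Arc770 deployment!"}],
    max_tokens=64,
)
print(resp.choices[0].message.content)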
@@ -150,12 +150,13 @@ def is_linear_module(module):
        if _VLLM_VERSION is None:
            _VLLM_VERSION = get_package_version('vllm')
        from vllm.model_executor.layers.linear import (
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear, MergedColumnParallelLinear
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
            MergedColumnParallelLinear, ReplicatedLinear
        )
        from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
        VLLM_LINEAR_LIST = [
            ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
            MergedColumnParallelLinear,
            MergedColumnParallelLinear, ReplicatedLinear,
        ]
        if 'xpu' in _VLLM_VERSION:
            VLLM_LINEAR_LIST.append(ParallelLMHead)
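Reviewer note: the hunk above widens the set of vLLM linear layers that IPEX-LLM will low-bit convert (adding ReplicatedLinear, plus ParallelLMHead on xpu builds). The membership test it feeds reduces to the sketch below; is_convertible_linear is an illustrative helper name, not the real function, and the real is_linear_module carries more logic around it.

# Simplified sketch of how the widened VLLM_LINEAR_LIST above is used: the converter
# only rewrites a module when it is an instance of one of the listed vLLM layer classes.
def is_convertible_linear(module, vllm_version: str) -> bool:
    from vllm.model_executor.layers.linear import (
        ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
        MergedColumnParallelLinear, ReplicatedLinear,
    )
    from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

    linear_classes = [ColumnParallelLinear, RowParallelLinear, QKVParallelLinear,
                      MergedColumnParallelLinear, ReplicatedLinear]
    if 'xpu' in vllm_version:   # ParallelLMHead is only handled on the xpu build
        linear_classes.append(ParallelLMHead)
    return isinstance(module, tuple(linear_classes))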
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine
from .engine import IPEXLLMAsyncLLMEngine, IPEXLLMLLMEngine, IPEXLLMClass, run_mp_engine, IPEXLLMAsyncV1Engine, IPEXLLMLLMV1Engine
__all__ = [
    "IPEXLLMAsyncLLMEngine",
    "IPEXLLMLLMEngine",
    "IPEXLLMClass",
    "IPEXLLMAsyncV1Engine",
    "IPEXLLMLLMV1Engine",
    "run_mp_engine",
]
			@ -38,6 +38,8 @@ logger = init_logger(__name__)
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
 | 
			
		||||
    _is_converted = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -53,13 +55,39 @@ class IPEXLLMAsyncLLMEngine(AsyncLLMEngine):
 | 
			
		|||
    ) -> "AsyncLLMEngine":
 | 
			
		||||
        """Creates an async LLM engine from the engine arguments."""
 | 
			
		||||
        # Create the engine configs.
 | 
			
		||||
        _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
 | 
			
		||||
                                        start_engine_loop=start_engine_loop,
 | 
			
		||||
                                        usage_context=usage_context, stat_loggers=stat_loggers)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_vllm_config(
 | 
			
		||||
        cls,
 | 
			
		||||
        vllm_config: VllmConfig,
 | 
			
		||||
        start_engine_loop: bool = True,
 | 
			
		||||
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
 | 
			
		||||
        stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
 | 
			
		||||
        disable_log_requests: bool = False,
 | 
			
		||||
        disable_log_stats: bool = False,
 | 
			
		||||
        load_in_low_bit: str = "sym_int4",
 | 
			
		||||
    ) -> "AsyncLLMEngine":
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_vllm_config(
 | 
			
		||||
            vllm_config=vllm_config,
 | 
			
		||||
            start_engine_loop=start_engine_loop,
 | 
			
		||||
            usage_context=usage_context,
 | 
			
		||||
            stat_loggers=stat_loggers,
 | 
			
		||||
            disable_log_requests=disable_log_requests,
 | 
			
		||||
            disable_log_stats=disable_log_stats,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IPEXLLMAsyncV1Engine(AsyncLLM):
 | 
			
		||||
    _is_converted = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
| 
						 | 
				
			
			@ -74,13 +102,39 @@ class IPEXLLMAsyncV1Engine(AsyncLLM):
 | 
			
		|||
        load_in_low_bit: str = "sym_int4",
 | 
			
		||||
        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,  # noqa
 | 
			
		||||
    ) -> "AsyncLLM":
 | 
			
		||||
        _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_engine_args(engine_args=engine_args, engine_config=engine_config,
 | 
			
		||||
                                        start_engine_loop=start_engine_loop,
 | 
			
		||||
                                        usage_context=usage_context, stat_loggers=stat_loggers)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_vllm_config(
 | 
			
		||||
        cls,
 | 
			
		||||
        vllm_config: VllmConfig,
 | 
			
		||||
        start_engine_loop: bool = True,
 | 
			
		||||
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
 | 
			
		||||
        stat_loggers: Optional[dict[str, StatLoggerBase]]=None,
 | 
			
		||||
        disable_log_requests: bool = False,
 | 
			
		||||
        disable_log_stats: bool = False,
 | 
			
		||||
        load_in_low_bit: str = "sym_int4",
 | 
			
		||||
    ) -> "AsyncLLM":
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_vllm_config(
 | 
			
		||||
            vllm_config=vllm_config,
 | 
			
		||||
            start_engine_loop=start_engine_loop,
 | 
			
		||||
            usage_context=usage_context,
 | 
			
		||||
            stat_loggers=stat_loggers,
 | 
			
		||||
            disable_log_requests=disable_log_requests,
 | 
			
		||||
            disable_log_stats=disable_log_stats,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IPEXLLMClass(LLM):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        model: str,
 | 
			
		||||
| 
						 | 
				
			
			@ -94,20 +148,20 @@ class IPEXLLMClass(LLM):
 | 
			
		|||
        quantization: Optional[str] = None,
 | 
			
		||||
        revision: Optional[str] = None,
 | 
			
		||||
        tokenizer_revision: Optional[str] = None,
 | 
			
		||||
        seed: int = 0,
 | 
			
		||||
        seed: Optional[int] = None,
 | 
			
		||||
        gpu_memory_utilization: float = 0.9,
 | 
			
		||||
        swap_space: float = 4,
 | 
			
		||||
        cpu_offload_gb: float = 0,
 | 
			
		||||
        enforce_eager: Optional[bool] = None,
 | 
			
		||||
        max_seq_len_to_capture: int = 8192,
 | 
			
		||||
        disable_custom_all_reduce: bool = False,
 | 
			
		||||
        disable_async_output_proc: bool = True,
 | 
			
		||||
        hf_overrides: Optional[HfOverrides] = None,
 | 
			
		||||
        mm_processor_kwargs: Optional[Dict[str, Any]]=None,
 | 
			
		||||
        disable_async_output_proc: bool = False,
 | 
			
		||||
        hf_overrides: Optional[HfOverrides]=None,
 | 
			
		||||
        mm_processor_kwargs: Optional[dict[str, Any]]=None,
 | 
			
		||||
        # After positional args are removed, move this right below `model`
 | 
			
		||||
        task: TaskOption = "auto",
 | 
			
		||||
        override_pooler_config: Optional[PoolerConfig] = None,
 | 
			
		||||
        compilation_config: Optional[Union[int, Dict[str, Any]]]=None,
 | 
			
		||||
        compilation_config: Optional[Union[int, dict[str, Any]]]=None,
 | 
			
		||||
        load_in_low_bit: str = "sym_int4",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -120,6 +174,13 @@ class IPEXLLMClass(LLM):
 | 
			
		|||
        if "disable_log_stats" not in kwargs:
 | 
			
		||||
            kwargs["disable_log_stats"] = True
 | 
			
		||||
 | 
			
		||||
        if "worker_cls" in kwargs:
 | 
			
		||||
            worker_cls = kwargs["worker_cls"]
 | 
			
		||||
            # if the worker_cls is not qualified string name,
 | 
			
		||||
            # we serialize it using cloudpickle to avoid pickling issues
 | 
			
		||||
            if isinstance(worker_cls, type):
 | 
			
		||||
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
 | 
			
		||||
 | 
			
		||||
        if compilation_config is not None:
 | 
			
		||||
            if isinstance(compilation_config, (int, dict)):
 | 
			
		||||
                compilation_config_instance = CompilationConfig.from_cli(
 | 
			
		||||
| 
						 | 
				
			
			@ -159,11 +220,13 @@ class IPEXLLMClass(LLM):
 | 
			
		|||
        # Logic to switch between engines is done at runtime instead of import
 | 
			
		||||
        # to avoid import order issues
 | 
			
		||||
        self.engine_class = self.get_engine_class()
 | 
			
		||||
        # print("!!! ", load_in_low_bit)
 | 
			
		||||
        self.llm_engine = self.engine_class.from_engine_args(
 | 
			
		||||
            engine_args, usage_context=UsageContext.LLM_CLASS,
 | 
			
		||||
            load_in_low_bit=load_in_low_bit)
 | 
			
		||||
 | 
			
		||||
        self.request_counter = Counter()
 | 
			
		||||
        self.default_sampling_params: Union[dict[str, Any], None] = None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_engine_class() -> Type[LLMEngine]:
 | 
			
		||||
| 
						 | 
				
			
			@ -173,6 +236,8 @@ class IPEXLLMClass(LLM):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class IPEXLLMLLMV1Engine(V1LLMEngine):
 | 
			
		||||
    _is_converted = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -188,14 +253,37 @@ class IPEXLLMLLMV1Engine(V1LLMEngine):
 | 
			
		|||
        """Creates an LLM engine from the engine arguments."""
 | 
			
		||||
        # Create the engine configs.
 | 
			
		||||
 | 
			
		||||
        _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_engine_args(engine_args,
 | 
			
		||||
                                        usage_context,
 | 
			
		||||
                                        stat_loggers,
 | 
			
		||||
                                        enable_multiprocessing)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_vllm_config(
 | 
			
		||||
        cls,
 | 
			
		||||
        vllm_config: VllmConfig,
 | 
			
		||||
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
 | 
			
		||||
        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
 | 
			
		||||
        disable_log_stats: bool = False,
 | 
			
		||||
        load_in_low_bit: str = "sym_int4",
 | 
			
		||||
    ) -> "LLMEngine":
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_vllm_config(
 | 
			
		||||
            vllm_config=vllm_config,
 | 
			
		||||
            usage_context=usage_context,
 | 
			
		||||
            stat_loggers=stat_loggers,
 | 
			
		||||
            disable_log_stats=disable_log_stats
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IPEXLLMLLMEngine(LLMEngine):
 | 
			
		||||
    _is_converted = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -209,33 +297,89 @@ class IPEXLLMLLMEngine(LLMEngine):
 | 
			
		|||
    ) -> "LLMEngine":
 | 
			
		||||
        """Creates an LLM engine from the engine arguments."""
 | 
			
		||||
        # Create the engine configs.
 | 
			
		||||
        _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_engine_args(engine_args, usage_context, stat_loggers)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_vllm_config(
 | 
			
		||||
        cls,
 | 
			
		||||
        vllm_config: VllmConfig,
 | 
			
		||||
        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
 | 
			
		||||
        stat_loggers: Optional[Dict[str, StatLoggerBase]]=None,
 | 
			
		||||
        disable_log_stats: bool = False,
 | 
			
		||||
        load_in_low_bit: str = "sym_int4",
 | 
			
		||||
    ) -> "LLMEngine":
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_vllm_config(
 | 
			
		||||
            vllm_config=vllm_config,
 | 
			
		||||
            usage_context=usage_context,
 | 
			
		||||
            stat_loggers=stat_loggers,
 | 
			
		||||
            disable_log_stats=disable_log_stats
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class IPEXLLMMQLLMEngine(MQLLMEngine):
 | 
			
		||||
    _is_converted = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_engine_args(cls, engine_args: AsyncEngineArgs,
 | 
			
		||||
                         usage_context: UsageContext, ipc_path: str, load_in_low_bit: str):
 | 
			
		||||
        _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_engine_args(engine_args, usage_context, ipc_path)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_vllm_config(cls, vllm_config: VllmConfig,
 | 
			
		||||
                         usage_context: UsageContext,
 | 
			
		||||
                         disable_log_requests: bool, disable_log_stats: bool,
 | 
			
		||||
                         ipc_path: str, load_in_low_bit: str) -> "MQLLMEngine":
 | 
			
		||||
 | 
			
		||||
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
 | 
			
		||||
                  ipc_path: str, load_in_low_bit: str, engine_alive):
 | 
			
		||||
        if not cls._is_converted:
 | 
			
		||||
            _ipex_llm_convert(load_in_low_bit)
 | 
			
		||||
            cls._is_converted = True
 | 
			
		||||
        return super().from_vllm_config(
 | 
			
		||||
            vllm_config=vllm_config,
 | 
			
		||||
            ipc_path=ipc_path,
 | 
			
		||||
            usage_context=usage_context,
 | 
			
		||||
            disable_log_requests=disable_log_requests,
 | 
			
		||||
            disable_log_stats=disable_log_stats,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def signal_handler(*_) -> None:
 | 
			
		||||
        # Interrupt server on sigterm
 | 
			
		||||
        raise KeyboardInterrupt("MQLLMEngine terminated")  # noqa
 | 
			
		||||
from vllm.transformers_utils.config import (
 | 
			
		||||
    maybe_register_config_serialize_by_value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def signal_handler(*_) -> None:
 | 
			
		||||
    raise KeyboardInterrupt("MQLLMEngine terminated")  # noqa
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext,
 | 
			
		||||
                  ipc_path: str, disable_log_stats: bool,
 | 
			
		||||
                  disable_log_requests: bool, load_in_low_bit, engine_alive):
 | 
			
		||||
    try:
 | 
			
		||||
        # Ensure we can serialize transformer config before spawning
 | 
			
		||||
        maybe_register_config_serialize_by_value()
 | 
			
		||||
 | 
			
		||||
        engine = IPEXLLMMQLLMEngine.from_vllm_config(
 | 
			
		||||
            vllm_config=vllm_config,
 | 
			
		||||
            usage_context=usage_context,
 | 
			
		||||
            disable_log_stats=disable_log_stats,
 | 
			
		||||
            disable_log_requests=disable_log_requests,
 | 
			
		||||
            load_in_low_bit=load_in_low_bit,
 | 
			
		||||
            ipc_path=ipc_path)
 | 
			
		||||
 | 
			
		||||
        signal.signal(signal.SIGTERM, signal_handler)
 | 
			
		||||
 | 
			
		||||
        engine = IPEXLLMMQLLMEngine.from_engine_args(engine_args=engine_args,
 | 
			
		||||
                                                     usage_context=usage_context,
 | 
			
		||||
                                                     ipc_path=ipc_path,
 | 
			
		||||
                                                     load_in_low_bit=load_in_low_bit)
 | 
			
		||||
        engine.start()
 | 
			
		||||
 | 
			
		||||
    except BaseException as e:
 | 
			
		||||
        logger.exception(e)
 | 
			
		||||
        engine_alive.value = False
 | 
			
		||||
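Reviewer note: every engine subclass in this file now guards _ipex_llm_convert() with a class-level _is_converted flag, so the low-bit model patching runs at most once per class even when both from_engine_args and from_vllm_config are exercised. Reduced to its core, the pattern looks like the sketch below; the mixin name and the stub for _ipex_llm_convert are illustrative only.

# Reduced sketch of the convert-once pattern used by the IPEXLLM* engine classes above.
def _ipex_llm_convert(load_in_low_bit: str) -> None:
    # stand-in for the converter imported by engine.py (real module path not shown here)
    print(f"converting vLLM layers to {load_in_low_bit}")

class ConvertOnceMixin:
    _is_converted = False  # class-level flag shared by all factory classmethods

    @classmethod
    def _convert_once(cls, load_in_low_bit: str) -> None:
        if not cls._is_converted:
            _ipex_llm_convert(load_in_low_bit)  # patch vLLM layers to low-bit kernels once
            cls._is_converted = True

# e.g. inside a from_vllm_config classmethod:
#     cls._convert_once(load_in_low_bit)
#     return super().from_vllm_config(vllm_config=vllm_config, ...)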
			@ -1,5 +1,8 @@
 | 
			
		|||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import atexit
 | 
			
		||||
import gc
 | 
			
		||||
import importlib
 | 
			
		||||
import inspect
 | 
			
		||||
import multiprocessing
 | 
			
		||||
| 
						 | 
				
			
			@ -10,16 +13,18 @@ import socket
 | 
			
		|||
import tempfile
 | 
			
		||||
import uuid
 | 
			
		||||
from argparse import Namespace
 | 
			
		||||
from collections.abc import AsyncIterator
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
from functools import partial
 | 
			
		||||
from http import HTTPStatus
 | 
			
		||||
from typing import AsyncIterator, Optional, Set, Tuple
 | 
			
		||||
from typing import Annotated, Optional, Union
 | 
			
		||||
 | 
			
		||||
import uvloop
 | 
			
		||||
from fastapi import APIRouter, FastAPI, Request
 | 
			
		||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
 | 
			
		||||
from fastapi.exceptions import RequestValidationError
 | 
			
		||||
from fastapi.middleware.cors import CORSMiddleware
 | 
			
		||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
 | 
			
		||||
from starlette.concurrency import iterate_in_threadpool
 | 
			
		||||
from starlette.datastructures import State
 | 
			
		||||
from starlette.routing import Mount
 | 
			
		||||
from typing_extensions import assert_never
 | 
			
		||||
| 
						 | 
				
			
			@ -27,17 +32,17 @@ from typing_extensions import assert_never
 | 
			
		|||
import vllm.envs as envs
 | 
			
		||||
from vllm.config import ModelConfig
 | 
			
		||||
from vllm.engine.arg_utils import AsyncEngineArgs
 | 
			
		||||
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine
 | 
			
		||||
from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncLLMEngine as AsyncLLMEngine  # type: ignore
 | 
			
		||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
 | 
			
		||||
from ipex_llm.vllm.xpu.engine import run_mp_engine
 | 
			
		||||
from vllm.engine.protocol import EngineClient
 | 
			
		||||
from vllm.entrypoints.chat_utils import load_chat_template
 | 
			
		||||
from vllm.entrypoints.chat_utils import (load_chat_template,
 | 
			
		||||
                                         resolve_hf_chat_template,
 | 
			
		||||
                                         resolve_mistral_chat_template)
 | 
			
		||||
from vllm.entrypoints.launcher import serve_http
 | 
			
		||||
from vllm.entrypoints.logger import RequestLogger
 | 
			
		||||
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
 | 
			
		||||
                                              validate_parsed_serve_args)
 | 
			
		||||
 | 
			
		||||
# from ipex_llm.vllm.xpu.entrypoints.openai.cli_args import make_arg_parser
 | 
			
		||||
# yapf conflicts with isort for this block
 | 
			
		||||
# yapf: disable
 | 
			
		||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 | 
			
		||||
| 
						 | 
				
			
			@ -46,33 +51,46 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 | 
			
		|||
                                              CompletionResponse,
 | 
			
		||||
                                              DetokenizeRequest,
 | 
			
		||||
                                              DetokenizeResponse,
 | 
			
		||||
                                              EmbeddingChatRequest,
 | 
			
		||||
                                              EmbeddingCompletionRequest,
 | 
			
		||||
                                              EmbeddingRequest,
 | 
			
		||||
                                              EmbeddingResponse,
 | 
			
		||||
                                              EmbeddingResponseData,
 | 
			
		||||
                                              ErrorResponse,
 | 
			
		||||
                                              LoadLoraAdapterRequest,
 | 
			
		||||
                                              LoadLoRAAdapterRequest,
 | 
			
		||||
                                              PoolingChatRequest,
 | 
			
		||||
                                              PoolingCompletionRequest,
 | 
			
		||||
                                              PoolingRequest, PoolingResponse,
 | 
			
		||||
                                              RerankRequest, RerankResponse,
 | 
			
		||||
                                              ScoreRequest, ScoreResponse,
 | 
			
		||||
                                              TokenizeRequest,
 | 
			
		||||
                                              TokenizeResponse,
 | 
			
		||||
                                              UnloadLoraAdapterRequest)
 | 
			
		||||
                                              TranscriptionRequest,
 | 
			
		||||
                                              TranscriptionResponse,
 | 
			
		||||
                                              UnloadLoRAAdapterRequest)
 | 
			
		||||
# yapf: enable
 | 
			
		||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 | 
			
		||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 | 
			
		||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 | 
			
		||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
 | 
			
		||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
 | 
			
		||||
                                                    OpenAIServingModels)
 | 
			
		||||
 | 
			
		||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
 | 
			
		||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 | 
			
		||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 | 
			
		||||
from vllm.entrypoints.openai.serving_score import ServingScores
 | 
			
		||||
from vllm.entrypoints.openai.serving_tokenization import (
 | 
			
		||||
    OpenAIServingTokenization)
 | 
			
		||||
from vllm.entrypoints.openai.serving_transcription import (
 | 
			
		||||
    OpenAIServingTranscription)
 | 
			
		||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 | 
			
		||||
from vllm.entrypoints.utils import with_cancellation
 | 
			
		||||
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
 | 
			
		||||
                                    with_cancellation)
 | 
			
		||||
from vllm.logger import init_logger
 | 
			
		||||
from vllm.reasoning import ReasoningParserManager
 | 
			
		||||
from vllm.transformers_utils.config import (
 | 
			
		||||
    maybe_register_config_serialize_by_value)
 | 
			
		||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
 | 
			
		||||
from vllm.usage.usage_lib import UsageContext
 | 
			
		||||
from vllm.utils import (FlexibleArgumentParser, get_open_zmq_ipc_path,
 | 
			
		||||
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
 | 
			
		||||
                        is_valid_ipv6_address, set_ulimit)
 | 
			
		||||
from vllm.version import __version__ as VLLM_VERSION
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -83,7 +101,7 @@ prometheus_multiproc_dir: tempfile.TemporaryDirectory
 | 
			
		|||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
 | 
			
		||||
logger = init_logger('vllm.entrypoints.openai.api_server')
 | 
			
		||||
 | 
			
		||||
_running_tasks: Set[asyncio.Task] = set()
 | 
			
		||||
_running_tasks: set[asyncio.Task] = set()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
| 
						 | 
				
			
			@ -102,6 +120,11 @@ async def lifespan(app: FastAPI):
 | 
			
		|||
            task.add_done_callback(_running_tasks.remove)
 | 
			
		||||
        else:
 | 
			
		||||
            task = None
 | 
			
		||||
 | 
			
		||||
        # Mark the startup heap as static so that it's ignored by GC.
 | 
			
		||||
        # Reduces pause times of oldest generation collections.
 | 
			
		||||
        gc.collect()
 | 
			
		||||
        gc.freeze()
 | 
			
		||||
        try:
 | 
			
		||||
            yield
 | 
			
		||||
        finally:
 | 
			
		||||
| 
						 | 
				
			
			@ -139,24 +162,49 @@ async def build_async_engine_client_from_engine_args(
 | 
			
		|||
    Returns the Client or None if the creation failed.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # Fall back
 | 
			
		||||
    # TODO: fill out feature matrix.
 | 
			
		||||
    if (MQLLMEngineClient.is_unsupported_config(engine_args)
 | 
			
		||||
            or envs.VLLM_USE_V1 or disable_frontend_multiprocessing):
 | 
			
		||||
    # Create the EngineConfig (determines if we can use V1).
 | 
			
		||||
    usage_context = UsageContext.OPENAI_API_SERVER
 | 
			
		||||
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)
 | 
			
		||||
 | 
			
		||||
    # V1 AsyncLLM.
 | 
			
		||||
    if envs.VLLM_USE_V1:
 | 
			
		||||
        if disable_frontend_multiprocessing:
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                "V1 is enabled, but got --disable-frontend-multiprocessing. "
 | 
			
		||||
                "To disable frontend multiprocessing, set VLLM_USE_V1=0.")
 | 
			
		||||
 | 
			
		||||
        from ipex_llm.vllm.xpu.engine import IPEXLLMAsyncV1Engine as AsyncLLM
 | 
			
		||||
        async_llm: Optional[AsyncLLM] = None
 | 
			
		||||
        try:
 | 
			
		||||
            async_llm = AsyncLLM.from_vllm_config(
 | 
			
		||||
                vllm_config=vllm_config,
 | 
			
		||||
                usage_context=usage_context,
 | 
			
		||||
                disable_log_requests=engine_args.disable_log_requests,
 | 
			
		||||
                disable_log_stats=engine_args.disable_log_stats,
 | 
			
		||||
                load_in_low_bit=load_in_low_bit)
 | 
			
		||||
            yield async_llm
 | 
			
		||||
        finally:
 | 
			
		||||
            if async_llm:
 | 
			
		||||
                async_llm.shutdown()
 | 
			
		||||
 | 
			
		||||
    # V0 AsyncLLM.
 | 
			
		||||
    elif (MQLLMEngineClient.is_unsupported_config(vllm_config)
 | 
			
		||||
          or disable_frontend_multiprocessing):
 | 
			
		||||
 | 
			
		||||
        engine_client: Optional[EngineClient] = None
 | 
			
		||||
        try:
 | 
			
		||||
            # When starting this, we are actually starting with the V1Engine
 | 
			
		||||
            # Here we are doing a classification, we will need to do this in IPEX-LLM
 | 
			
		||||
            engine_client = AsyncLLMEngine.from_engine_args(
 | 
			
		||||
                engine_args=engine_args,
 | 
			
		||||
                usage_context=UsageContext.OPENAI_API_SERVER,
 | 
			
		||||
            engine_client = AsyncLLMEngine.from_vllm_config(
 | 
			
		||||
                vllm_config=vllm_config,
 | 
			
		||||
                usage_context=usage_context,
 | 
			
		||||
                disable_log_requests=engine_args.disable_log_requests,
 | 
			
		||||
                disable_log_stats=engine_args.disable_log_stats,
 | 
			
		||||
                load_in_low_bit=load_in_low_bit)
 | 
			
		||||
            yield engine_client
 | 
			
		||||
        finally:
 | 
			
		||||
            if engine_client and hasattr(engine_client, "shutdown"):
 | 
			
		||||
                engine_client.shutdown()
 | 
			
		||||
 | 
			
		||||
    # Otherwise, use the multiprocessing AsyncLLMEngine.
 | 
			
		||||
    # V0MQLLMEngine.
 | 
			
		||||
    else:
 | 
			
		||||
        if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
 | 
			
		||||
            # Make TemporaryDirectory for prometheus multiprocessing
 | 
			
		||||
| 
						 | 
				
			
			@ -183,14 +231,18 @@ async def build_async_engine_client_from_engine_args(
 | 
			
		|||
        # so we need to spawn a new process
 | 
			
		||||
        context = multiprocessing.get_context("spawn")
 | 
			
		||||
 | 
			
		||||
        # Ensure we can serialize transformer config before spawning
 | 
			
		||||
        maybe_register_config_serialize_by_value()
 | 
			
		||||
 | 
			
		||||
        # The Process can raise an exception during startup, which may
 | 
			
		||||
        # not actually result in an exitcode being reported. As a result
 | 
			
		||||
        # we use a shared variable to communicate the information.
 | 
			
		||||
        engine_alive = multiprocessing.Value('b', True, lock=False)
 | 
			
		||||
        engine_process = context.Process(target=run_mp_engine,
 | 
			
		||||
                                         args=(engine_args,
 | 
			
		||||
                                               UsageContext.OPENAI_API_SERVER,
 | 
			
		||||
                                               ipc_path, load_in_low_bit, engine_alive))
 | 
			
		||||
        engine_process = context.Process(
 | 
			
		||||
            target=run_mp_engine,
 | 
			
		||||
            args=(vllm_config, UsageContext.OPENAI_API_SERVER, ipc_path,
 | 
			
		||||
                  engine_args.disable_log_stats,
 | 
			
		||||
                  engine_args.disable_log_requests, load_in_low_bit, engine_alive))
 | 
			
		||||
        engine_process.start()
 | 
			
		||||
        engine_pid = engine_process.pid
 | 
			
		||||
        assert engine_pid is not None, "Engine process failed to start."
 | 
			
		||||
| 
						 | 
				
			
			@ -205,8 +257,7 @@ async def build_async_engine_client_from_engine_args(
 | 
			
		|||
        atexit.register(_cleanup_ipc_path)
 | 
			
		||||
 | 
			
		||||
        # Build RPCClient, which conforms to EngineClient Protocol.
 | 
			
		||||
        engine_config = engine_args.create_engine_config()
 | 
			
		||||
        build_client = partial(MQLLMEngineClient, ipc_path, engine_config,
 | 
			
		||||
        build_client = partial(MQLLMEngineClient, ipc_path, vllm_config,
 | 
			
		||||
                               engine_pid)
 | 
			
		||||
        mq_engine_client = await asyncio.get_running_loop().run_in_executor(
 | 
			
		||||
            None, build_client)
 | 
			
		||||
| 
						 | 
				
			
			@ -244,6 +295,43 @@ async def build_async_engine_client_from_engine_args(
 | 
			
		|||
            multiprocess.mark_process_dead(engine_process.pid)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def validate_json_request(raw_request: Request):
 | 
			
		||||
    content_type = raw_request.headers.get("content-type", "").lower()
 | 
			
		||||
    media_type = content_type.split(";", maxsplit=1)[0]
 | 
			
		||||
    if media_type != "application/json":
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
 | 
			
		||||
            detail="Unsupported Media Type: Only 'application/json' is allowed"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
save_dict = {}
 | 
			
		||||
import os
 | 
			
		||||
flag = os.getenv("VLLM_LOG_OUTPUT", None)
 | 
			
		||||
async def stream_generator(generator, request, request_id):
 | 
			
		||||
    async for chunk in generator:
 | 
			
		||||
        if request_id not in save_dict:
 | 
			
		||||
            save_dict[request_id] = ""
 | 
			
		||||
        import json
 | 
			
		||||
        try:
 | 
			
		||||
            data = chunk.strip()
 | 
			
		||||
            if data.startswith('data: '):
 | 
			
		||||
                data = data[len('data: '):]
 | 
			
		||||
            else:
 | 
			
		||||
                yield chunk
 | 
			
		||||
            json_data = json.loads(data)
 | 
			
		||||
            if 'choices' in json_data and len(json_data['choices']) > 0:
 | 
			
		||||
                choice = json_data['choices'][0]
 | 
			
		||||
                if 'delta' in choice:
 | 
			
		||||
                    save_dict[request_id] += choice["delta"]["content"]
 | 
			
		||||
                elif 'text' in choice:
 | 
			
		||||
                    save_dict[request_id] += choice["text"]
 | 
			
		||||
        except json.JSONDecodeError:
 | 
			
		||||
            print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
 | 
			
		||||
            pass  # Done
 | 
			
		||||
        yield chunk
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -254,6 +342,7 @@ def mount_metrics(app: FastAPI):
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)
    from prometheus_fastapi_instrumentator import Instrumentator

    prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if prometheus_multiproc_dir_path is not None:
@@ -261,6 +350,16 @@ def mount_metrics(app: FastAPI):
                     prometheus_multiproc_dir_path)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        Instrumentator(
            excluded_handlers=[
                "/metrics",
                "/health",
                "/load",
                "/ping",
                "/version",
            ],
            registry=registry,
        ).add().instrument(app).expose(app)

        # Add prometheus asgi middleware to route /metrics requests
        metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
@@ -298,7 +397,11 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    return request.app.state.openai_serving_embedding


def score(request: Request) -> Optional[OpenAIServingScores]:
def score(request: Request) -> Optional[ServingScores]:
    return request.app.state.openai_serving_scores


def rerank(request: Request) -> Optional[ServingScores]:
    return request.app.state.openai_serving_scores


@@ -306,6 +409,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization


def transcription(request: Request) -> OpenAIServingTranscription:
    return request.app.state.openai_serving_transcription


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client
@@ -317,7 +424,31 @@ async def health(raw_request: Request) -> Response:
    return Response(status_code=200)


@router.post("/tokenize")
@router.get("/load")
async def get_server_load_metrics(request: Request):
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/embeddings
    # - /pooling
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    return JSONResponse(
        content={'server_load': request.app.state.server_load_metrics})
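The /load route above simply reports app.state.server_load_metrics, the in-flight request counter maintained by the @load_aware_call decorator when server load tracking is enabled. A hedged client-side sketch (the host, port, and --enable-server-load-tracking launch flag are assumptions about how the server was started):

import json
import urllib.request

# Query the /load endpoint of a locally running server.
with urllib.request.urlopen("http://localhost:8000/load") as resp:
    payload = json.load(resp)

# Expected shape: {"server_load": <number of tracked in-flight requests>}
print("in-flight requests:", payload["server_load"])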
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    return await health(raw_request)


@router.post("/tokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)
@@ -332,7 +463,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
    assert_never(generator)


@router.post("/detokenize")
@router.post("/detokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)
@@ -361,35 +492,10 @@ async def show_version():
    return JSONResponse(content=ver)


save_dict = {}
import os
flag = os.getenv("VLLM_LOG_OUTPUT", None)
async def stream_generator(generator, request, request_id):
    async for chunk in generator:
        if request_id not in save_dict:
            save_dict[request_id] = ""
        import json
        try:
            data = chunk.strip()
            if data.startswith('data: '):
                data = data[len('data: '):]
            else:
                yield chunk
            json_data = json.loads(data)
            if 'choices' in json_data and len(json_data['choices']) > 0:
                choice = json_data['choices'][0]
                if 'delta' in choice:
                    save_dict[request_id] += choice["delta"]["content"]
                elif 'text' in choice:
                    save_dict[request_id] += choice["text"]
        except json.JSONDecodeError:
            print(f"Received request_id: {request_id}, request: {request} content: {save_dict[request_id]}")
            pass  # Done
        yield chunk
@router.post("/v1/chat/completions")
@router.post("/v1/chat/completions",
             dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    handler = chat(raw_request)
@@ -401,7 +507,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
        request_id = "chatcmpl-" \
                f"{handler._base_request_id(raw_request, request.request_id)}"
        print(f"First received request_id: {request_id}, request: {request}")

    generator = await handler.create_chat_completion(request, raw_request)

    if isinstance(generator, ErrorResponse):
@@ -418,8 +523,9 @@ async def create_chat_completion(request: ChatCompletionRequest,
    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post("/v1/completions")
@router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler = completion(raw_request)
    if handler is None:
@@ -438,14 +544,15 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
        if flag is not None:
            print(f"Received request-id:{request_id}, request:{request}, Output:{generator.model_dump()}")
        return JSONResponse(content=generator.model_dump())

    if flag is not None:
        return StreamingResponse(content=stream_generator(generator, request, request_id), media_type="text/event-stream")
    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post("/v1/embeddings")
@router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    handler = embedding(raw_request)
    if handler is None:
@@ -460,6 +567,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
            "use the Pooling API (`/pooling`) instead.")

        res = await fallback_handler.create_pooling(request, raw_request)

        generator: Union[ErrorResponse, EmbeddingResponse]
        if isinstance(res, PoolingResponse):
            generator = EmbeddingResponse(
                id=res.id,
@@ -488,8 +597,9 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    assert_never(generator)


@router.post("/pooling")
@router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    handler = pooling(raw_request)
    if handler is None:
@@ -506,8 +616,9 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
    assert_never(generator)


@router.post("/score")
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
    handler = score(raw_request)
    if handler is None:
@@ -524,8 +635,9 @@ async def create_score(request: ScoreRequest, raw_request: Request):
    assert_never(generator)


@router.post("/v1/score")
@router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    logger.warning(
        "To indicate that Score API is not part of standard OpenAI API, we "
@@ -534,6 +646,160 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
    return await create_score(request, raw_request)


@router.post("/v1/audio/transcriptions")
@with_cancellation
@load_aware_call
async def create_transcriptions(request: Annotated[TranscriptionRequest,
                                                   Form()],
                                raw_request: Request):
    handler = transcription(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Transcriptions API")

    audio_data = await request.file.read()
    generator = await handler.create_transcription(audio_data, request,
                                                   raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)

    elif isinstance(generator, TranscriptionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    handler = rerank(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API")
    generator = await handler.do_rerank(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, RerankResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


@router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    logger.warning_once(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)")

    return await do_rerank(request, raw_request)


@router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    return await do_rerank(request, raw_request)
TASK_HANDLERS: dict[str, dict[str, tuple]] = {
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    "score": {
        "default": (RerankRequest, do_rerank)
    },
    "rerank": {
        "default": (RerankRequest, do_rerank)
    },
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
    "classify": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
}
if envs.VLLM_SERVER_DEV_MODE:

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        device = None
        device_str = raw_request.query_params.get("device")
        if device_str is not None:
            device = Device[device_str.upper()]
        logger.info("Resetting prefix cache with specific %s...", str(device))
        await engine_client(raw_request).reset_prefix_cache(device)
        return Response(status_code=200)

    @router.post("/sleep")
    async def sleep(raw_request: Request):
        # get POST params
        level = raw_request.query_params.get("level", "1")
        await engine_client(raw_request).sleep(int(level))
        # FIXME: in v0 with frontend multiprocessing, the sleep command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.post("/wake_up")
    async def wake_up(raw_request: Request):
        tags = raw_request.query_params.getlist("tags")
        if tags == []:
            # set to None to wake up all tags if no tags are provided
            tags = None
        logger.info("wake up the engine with tags: %s", tags)
        await engine_client(raw_request).wake_up(tags)
        # FIXME: in v0 with frontend multiprocessing, the wake-up command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.get("/is_sleeping")
    async def is_sleeping(raw_request: Request):
        logger.info("check whether the engine is sleeping")
        is_sleeping = await engine_client(raw_request).is_sleeping()
        return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request):
    """
    For SageMaker, routes requests to other handlers based on model `task`.
    """
    body = await raw_request.json()
    task = raw_request.app.state.task

    if task not in TASK_HANDLERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}")

    handler_config = TASK_HANDLERS[task]
    if "messages" in body:
        request_model, handler = handler_config["messages"]
    else:
        request_model, handler = handler_config["default"]

    # this is required since we lose the FastAPI automatic casting
    request = request_model.model_validate(body)
    return await handler(request, raw_request)
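The /invocations route above dispatches on the model task and on whether the JSON body carries a "messages" field. An illustrative pair of payloads, assuming the server's task is "generate" (model name and prompt text are placeholders):

# "messages" present -> validated as ChatCompletionRequest and handled by create_chat_completion
chat_style = {
    "model": "my-model",  # placeholder served model name
    "messages": [{"role": "user", "content": "Hello"}],
}

# no "messages" -> validated as CompletionRequest and handled by create_completion
completion_style = {
    "model": "my-model",
    "prompt": "Hello",
}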
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -556,32 +822,30 @@ if envs.VLLM_TORCH_PROFILER_DIR:

if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "LoRA dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")
    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
    @router.post("/v1/load_lora_adapter",
                 dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest,
                                raw_request: Request):
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                response = await handler.load_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)
        handler = models(raw_request)
        response = await handler.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
    @router.post("/v1/unload_lora_adapter",
                 dependencies=[Depends(validate_json_request)])
    async def unload_lora_adapter(request: UnloadLoRAAdapterRequest,
                                  raw_request: Request):
        for route in [chat, completion, embedding]:
            handler = route(raw_request)
            if handler is not None:
                response = await handler.unload_lora_adapter(request)
                if isinstance(response, ErrorResponse):
                    return JSONResponse(content=response.model_dump(),
                                        status_code=response.code)
        handler = models(raw_request)
        response = await handler.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(content=response.model_dump(),
                                status_code=response.code)

        return Response(status_code=200, content=response)
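The dynamic LoRA routes above now delegate to the shared models(raw_request) handler instead of probing each serving handler in turn. A hedged sketch of calling the load endpoint, assuming the server runs with VLLM_ALLOW_RUNTIME_LORA_UPDATING=True on localhost:8000 and that LoadLoRAAdapterRequest exposes lora_name / lora_path fields as in upstream vLLM (adapter name and path below are placeholders):

import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8000/v1/load_lora_adapter",
    data=json.dumps({"lora_name": "my_adapter",          # placeholder adapter name
                     "lora_path": "/path/to/adapter"}).encode(),
    headers={"Content-Type": "application/json"},        # required by validate_json_request
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(resp.status)  # 200 on success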
@@ -615,7 +879,8 @@ def build_app(args: Namespace) -> FastAPI:
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if token := args.api_key or envs.VLLM_API_KEY:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
@@ -644,11 +909,26 @@ def build_app(args: Namespace) -> FastAPI:
            response.headers["X-Request-Id"] = request_id
            return response

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning("CAUTION: Enabling log response in the API Server. "
                       "This can include sensitive information and should be "
                       "avoided in production.")

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            response_body = [
                section async for section in response.body_iterator
            ]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            logger.info("response_body={%s}", response_body[0].decode())
            return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
@@ -658,7 +938,7 @@ def build_app(args: Namespace) -> FastAPI:
    return app


def init_app_state(
async def init_app_state(
    engine_client: EngineClient,
    model_config: ModelConfig,
    state: State,
@@ -683,15 +963,36 @@ def init_app_state(
    state.log_stats = not args.disable_log_stats

    resolved_chat_template = load_chat_template(args.chat_template)
    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check official template
        tokenizer = await engine_client.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            # The warning is logged in resolve_mistral_chat_template.
            resolved_chat_template = resolve_mistral_chat_template(
                chat_template=resolved_chat_template)
        else:
            hf_chat_template = resolve_hf_chat_template(
                tokenizer,
                chat_template=None,
                tools=None,
                trust_remote_code=model_config.trust_remote_code)

            if hf_chat_template != resolved_chat_template:
                logger.warning(
                    "Using supplied chat template: %s\n"
                    "It is different from official chat template '%s'. "
                    "This discrepancy may lead to performance degradation.",
                    resolved_chat_template, args.model)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    # TODO: The chat template is now broken for lora adapters :(
    await state.openai_serving_models.init_static_loras()
    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
@@ -703,6 +1004,8 @@ def init_app_state(
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        enable_reasoning=args.enable_reasoning,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    state.openai_serving_completion = OpenAIServingCompletion(
@@ -728,7 +1031,13 @@ def init_app_state(
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.task == "embed" else None
    state.openai_serving_scores = OpenAIServingScores(
    state.openai_serving_scores = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger) if model_config.task in (
            "score", "embed", "pooling") else None
    state.jinaai_serving_reranking = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
@@ -742,92 +1051,26 @@ def init_app_state(
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    )
    state.openai_serving_transcription = OpenAIServingTranscription(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
    ) if model_config.runner_type == "transcription" else None
    state.task = model_config.task
    # if args.served_model_name is not None:
    #     served_model_names = args.served_model_name
    # else:
    #     served_model_names = [args.model]

    # if args.disable_log_requests:
    #     request_logger = None
    # else:
    #     request_logger = RequestLogger(max_log_len=args.max_log_len)

    # base_model_paths = [
    #     BaseModelPath(name=name, model_path=args.model)
    #     for name in served_model_names
    # ]

    # state.engine_client = engine_client
    # state.log_stats = not args.disable_log_stats

    # resolved_chat_template = load_chat_template(args.chat_template)
    # logger.info("Using supplied chat template:\n%s", resolved_chat_template)

    # state.openai_serving_chat = OpenAIServingChat(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     args.response_role,
    #     lora_modules=args.lora_modules,
    #     prompt_adapters=args.prompt_adapters,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    #     enable_auto_tools=args.enable_auto_tool_choice,
    #     tool_parser=args.tool_call_parser,
    #     enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    # ) if model_config.runner_type == "generate" else None
    # state.openai_serving_completion = OpenAIServingCompletion(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     lora_modules=args.lora_modules,
    #     prompt_adapters=args.prompt_adapters,
    #     request_logger=request_logger,
    #     return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    # ) if model_config.runner_type == "generate" else None
    # state.openai_serving_pooling = OpenAIServingPooling(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    # ) if model_config.runner_type == "pooling" else None
    # state.openai_serving_embedding = OpenAIServingEmbedding(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    # ) if model_config.task == "embed" else None
    # state.openai_serving_scores = OpenAIServingScores(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     request_logger=request_logger
    # ) if model_config.task == "score" else None
    # state.openai_serving_tokenization = OpenAIServingTokenization(
    #     engine_client,
    #     model_config,
    #     base_model_paths,
    #     lora_modules=args.lora_modules,
    #     request_logger=request_logger,
    #     chat_template=resolved_chat_template,
    #     chat_template_content_format=args.chat_template_content_format,
    # )
    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    sock.bind(addr)

    return sock
@@ -840,11 +1083,18 @@ async def run_server(args, **uvicorn_kwargs) -> None:
    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    valide_tool_parses = ToolParserManager.tool_parsers.keys()
    valid_tool_parses = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice \
        and args.tool_call_parser not in valide_tool_parses:
        and args.tool_call_parser not in valid_tool_parses:
        raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
                       f"(chose from {{ {','.join(valide_tool_parses)} }})")
                       f"(chose from {{ {','.join(valid_tool_parses)} }})")

    valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys()
    if args.enable_reasoning \
        and args.reasoning_parser not in valid_reasoning_parses:
        raise KeyError(
            f"invalid reasoning parser: {args.reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parses)} }})")

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
@@ -866,13 +1116,28 @@ async def run_server(args, **uvicorn_kwargs) -> None:
        app = build_app(args)

        model_config = await engine_client.get_model_config()
        init_app_state(engine_client, model_config, app.state, args)
        await init_app_state(engine_client, model_config, app.state, args)

        def _listen_addr(a: str) -> str:
            if is_valid_ipv6_address(a):
                return '[' + a + ']'
            return a or "0.0.0.0"

        is_ssl = args.ssl_keyfile and args.ssl_certfile
        logger.info("Starting vLLM API server on http%s://%s:%d",
                    "s" if is_ssl else "", _listen_addr(sock_addr[0]),
                    sock_addr[1])

        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # NOTE: When the 'disable_uvicorn_access_log' value is True,
            # no access log will be output.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
@@ -882,16 +1147,19 @@ async def run_server(args, **uvicorn_kwargs) -> None:
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

    sock.close()
    try:
        await shutdown_task
    finally:
        sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    logger.warning("Warning: Please use `ipex_llm.vllm.xpu.entrypoints.openai.api_server` "
                   "instead of `vllm.entrypoints.openai.api_server` to start the API server")
    cli_env_setup()
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser = make_arg_parser(parser)
@@ -48,7 +48,7 @@ def _sample_get_logits(
        logits = lm_head(hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias
    if self.use_gather:
    if self.use_all_gather:
        logits = tensor_model_parallel_gather(logits)
    else:
        logits = tensor_model_parallel_all_gather(logits)
@@ -63,6 +63,8 @@ def _model_sample_convert():


def _ipex_llm_convert(load_in_low_bit):
    # import pdb
    # pdb.set_trace()
    from vllm.worker.xpu_model_runner import XPUModelRunner
    from ipex_llm.vllm.xpu.ipex_llm_wrapper import get_ipex_llm_wrapper
    from ipex_llm.vllm.xpu.ipex_llm_v1_wrapper import get_ipex_llm_v1_wrapper
@@ -99,7 +101,8 @@ def get_load_function(low_bit):
                        "codegeex4-all" in self.vllm_config.model_config.model.lower() or
                        "chatglm" in self.vllm_config.model_config.model.lower()) and \
                        "gptq" not in self.model_config.model.lower() and \
                        "awq" not in self.model_config.model.lower():
                        "awq" not in self.model_config.model.lower() and \
                        "qwen3" not in self.model_config.model.lower():
                    self.model.apply(padding_mlp)
                from ipex_llm import optimize_model
                not_convert_last_mlp = os.getenv("IPEX_LLM_NOT_CONVERT_LAST_MLP", None)