From 73198d5b80dd584de58fc3625ca0cdf78b4f8e42 Mon Sep 17 00:00:00 2001 From: Shaojun Liu <61072813+liu-shaojun@users.noreply.github.com> Date: Thu, 17 Apr 2025 16:18:22 +0800 Subject: [PATCH] Update to b17 image (#13085) * update vllm patch * fix * fix triton --------- Co-authored-by: gc-fu --- docker/llm/serving/xpu/docker/Dockerfile | 4 + .../xpu/docker/vllm_for_multi_arc.patch | 968 +++++++++++++++--- 2 files changed, 850 insertions(+), 122 deletions(-) diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index fb346332..faa5fd70 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -168,6 +168,10 @@ RUN set -eux && \ mkdir -p /llm && \ cd /llm && \ rm -rf /tmp/neo && \ + # Install intel_extension_for_pytorch + pip install intel-extension-for-pytorch==2.6.10+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ && \ + pip uninstall -y oneccl oneccl-devel && \ + pip install intel-opencl-rt==2025.0.2 intel-openmp==2025.0.2 && \ # # Install vllm git clone -b v0.6.6.post1 https://github.com/vllm-project/vllm /llm/vllm && \ diff --git a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch index 7e4c62d6..e2121a65 100644 --- a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch +++ b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch @@ -1964,10 +1964,10 @@ index 40430dae1..76efeda6c 100644 if (GPU_LANGUAGE STREQUAL "CUDA") diff --git a/cmake/xpu_extension.cmake b/cmake/xpu_extension.cmake new file mode 100644 -index 000000000..085b6bb7d +index 000000000..a99dcd5a3 --- /dev/null +++ b/cmake/xpu_extension.cmake -@@ -0,0 +1,59 @@ +@@ -0,0 +1,61 @@ +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +# @@ -2003,7 +2003,9 @@ index 000000000..085b6bb7d +set(VLLM_EXT_SRC + "csrc/xpu/activation_xpu.cpp" + "csrc/xpu/attention_xpu.cpp" ++ "csrc/xpu/attention_xpu_fp8.cpp" + "csrc/xpu/cache_ops_xpu.cpp" ++ "csrc/xpu/cache_ops_xpu_fp8.cpp" + "csrc/xpu/gemm_kernels_xpu.cpp" + "csrc/xpu/layernorm_xpu.cpp" + "csrc/xpu/pos_encoding_xpu.cpp" @@ -6547,24 +6549,358 @@ index 000000000..97d5c0c21 + query.device() + ); +} -diff --git a/csrc/xpu/cache_ops_xpu.cpp b/csrc/xpu/cache_ops_xpu.cpp +diff --git a/csrc/xpu/attention_xpu_fp8.cpp b/csrc/xpu/attention_xpu_fp8.cpp new file mode 100644 -index 000000000..381795cda +index 000000000..a2ea5819b --- /dev/null -+++ b/csrc/xpu/cache_ops_xpu.cpp -@@ -0,0 +1,575 @@ ++++ b/csrc/xpu/attention_xpu_fp8.cpp +@@ -0,0 +1,324 @@ +// clang-format off +#ifdef VLLM_DEV +#undef __SYCL_DEVICE_ONLY__ +#endif +#include +#include ++#include ++#include "kv.h" ++ ++// clang-format on ++#include ++#include ++#include ++#include "utils.h" ++#include "xpu_types.h" ++// #include "dtype_bfloat16.dp.hpp" ++#include "dtype_float16.h" ++#include "dtype_float32.h" ++#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 3 ++#include ++#endif ++ ++#include ++// #include ++ ++using namespace sycl::ext::intel::esimd; ++using AT = at::ScalarType; ++ ++template ++void gqa_1_kernel_fp8( ++ const void* query, // [num_seqs, num_heads, head_size] ++ const void* key, // [num_blocks, num_kv_heads, head_size, block_size] ++ const void* value, // [num_blocks, num_kv_heads, head_size, block_size] ++ const void* block_tables, // [num_seqs, max_num_blocks_per_seq] ++ const void* context_lens, // [num_seqs] ++ void* o_a_s, void* o_accs, const int64_t query_bsz_stride, ++ const int64_t query_head_stride, const int64_t kv_token_stride, ++ 
const int64_t kv_head_stride, const int64_t kv_block_stride, ++ const int64_t block_table_stride_batch, const int64_t o_a_s_bsz_stride, ++ const int64_t o_a_s_head_stride, const int64_t o_accs_bsz_stride, ++ const int64_t o_accs_head_stride, const float scale, const int block_size, ++ const int bsz, const int num_heads, const int num_kv_heads, ++ const int block_num, const at::Device& device) { ++ const int group_size = num_heads / num_kv_heads; ++ const int sub_rows = VS / group_size; ++ const int rem_rows = VS % group_size; ++ ++ const float attn_scale = scale; ++ ++ sycl::range<3> global_size(bsz, num_heads, block_num); ++ sycl::range<3> local_size(1, group_size, 1); ++ ++ auto cgf = [&](sycl::handler& handle) { ++ handle.parallel_for( ++ sycl::nd_range<3>(global_size, local_size), ++ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { ++ slm_init(); ++ ++ const int bsz_idx = item.get_global_id(0); ++ const int head_idx = item.get_global_id(1); ++ const int kv_head_idx = item.get_group(1); ++ const int tid = item.get_local_id(1); ++ const int vid = item.get_global_id(2); ++ ++ const IT* query_head = (const IT*)query + bsz_idx * query_bsz_stride + ++ head_idx * query_head_stride; ++ ++ IT* o_accs_head = (IT*)o_accs + bsz_idx * o_accs_bsz_stride + ++ head_idx * o_accs_head_stride; ++ float* o_a_s_head = (float*)o_a_s + bsz_idx * o_a_s_bsz_stride + ++ head_idx * o_a_s_head_stride; ++ ++ const int* block_tables_ptr = (const int*)block_tables; ++ const int* block_table = ++ block_tables_ptr + bsz_idx * block_table_stride_batch; ++ ++ const int* context_lens_ptr = (const int*)context_lens; ++ const int context_length = context_lens_ptr[bsz_idx]; ++ ++ simd query_row = block_load(query_head) * attn_scale; ++ ++ // copy k_cache to slm ++ int start_row = ++ std::min(vid * VS + tid * sub_rows + std::min(tid, rem_rows), ++ context_length); ++ int end_row = ++ std::min(start_row + sub_rows + (tid < rem_rows), context_length); ++ for (int r = start_row; r < end_row; ++r) { ++ int which_block = r / block_size; ++ int which_slot = r % block_size; ++ int physical_block_number = block_table[which_block]; ++ ++ // Load elements in uint8_t ++ const uint8_t* key_head = ++ (const uint8_t*)key + physical_block_number * kv_token_stride + ++ kv_head_idx * kv_head_stride + which_slot * kv_block_stride; ++ ++ simd key_row = block_load(key_head); ++ simd key_dequantized = dequantize_key_row(key_row); ++ slm_block_store((r - vid * VS) * HD * sizeof(IT), key_dequantized); ++ } ++ barrier(); ++ ++ simd attns = -sycl::detail::max_v(); ++ int row_num = ++ (vid + 1) * VS > context_length ? 
context_length % VS : VS; ++ // q @ k ++ for (int r = 0; r < row_num; ++r) { ++ simd key_row = slm_block_load(r * HD * sizeof(IT)); ++ float attn = sycl::ext::intel::esimd::detail::sum( ++ query_row * key_row); ++ attns[r] = attn; ++ } ++ ++ float max_attn = hmax(attns); ++ const simd attn_exp = exp(attns - max_attn); ++ barrier(); ++ ++ // copy v_cache to slm ++ for (int r = start_row; r < end_row; ++r) { ++ int which_block = r / block_size; ++ int which_slot = r % block_size; ++ int physical_block_number = block_table[which_block]; ++ ++ const uint8_t* value_head = ++ (const uint8_t*)value + physical_block_number * kv_token_stride + ++ kv_head_idx * kv_head_stride + which_slot * kv_block_stride; ++ ++ simd value_row = block_load(value_head); ++ simd value_dequantized = dequantize_value_row(value_row); ++ slm_block_store((r - vid * VS) * HD * sizeof(IT), ++ value_dequantized); ++ } ++ barrier(); ++ ++ // attn @ v ++ simd accs = 0; ++ for (int r = 0; r < row_num; ++r) { ++ simd value_row = ++ slm_block_load(r * HD * sizeof(IT)); ++ accs = accs + value_row * attn_exp[r]; ++ } ++ ++ float softmax = ++ sycl::ext::intel::esimd::detail::sum(attn_exp); ++ ++ block_store(o_accs_head + vid * HD, accs); ++ block_store(o_a_s_head + vid * 2, max_attn); ++ block_store(o_a_s_head + vid * 2 + 1, softmax); ++ }); ++ }; ++ ++ utils::submit_kernel(cgf, device, "gqa kernel 1/2"); ++} ++ ++template ++void gqa_2_kernel_fp8(void* o_a_s, void* o_accs, void* output, ++ const void* context_lens, // [num_seqs] ++ const int64_t o_a_s_bsz_stride, ++ const int64_t o_a_s_head_stride, ++ const int64_t o_accs_bsz_stride, ++ const int64_t o_accs_head_stride, ++ const int64_t output_bsz_stride, ++ const int64_t output_head_stride, const int bsz, ++ const int num_heads, const int row_block_num, ++ const at::Device& device) { ++ constexpr int SUB_HD = 8; ++ static_assert(HD % SUB_HD == 0); ++ static_assert(HD / SUB_HD <= GS); ++ ++ const int sub_rows = row_block_num / GS; ++ const int rem_rows = row_block_num % GS; ++ ++ constexpr int accs_slm_offset = 0; ++ constexpr int attn_slm_offset = GS * HD * sizeof(float); ++ constexpr int softmax_slm_offset = attn_slm_offset + GS * sizeof(float); ++ ++ sycl::range<3> global_size(bsz, num_heads, GS); ++ sycl::range<3> local_size(1, 1, GS); ++ ++ auto cgf = [&](sycl::handler& handle) { ++ handle.parallel_for( ++ sycl::nd_range<3>(global_size, local_size), ++ [=](sycl::nd_item<3> item) SYCL_ESIMD_KERNEL { ++ slm_init(); ++ ++ const int bsz_idx = item.get_global_id(0); ++ const int head_idx = item.get_global_id(1); ++ const int tid = item.get_global_id(2); ++ ++ const int* context_lens_ptr = (const int*)context_lens; ++ const int context_length = context_lens_ptr[bsz_idx]; ++ constexpr int VS = 32; ++ const int cur_row_block_num = (context_length + VS - 1) / VS; ++ const int cur_sub_rows = cur_row_block_num / GS; ++ const int cur_rem_rows = cur_row_block_num % GS; ++ ++ const float* o_a_s_head = (const float*)o_a_s + ++ bsz_idx * o_a_s_bsz_stride + ++ head_idx * o_a_s_head_stride; ++ const IT* o_accs_head = (const IT*)o_accs + ++ bsz_idx * o_accs_bsz_stride + ++ head_idx * o_accs_head_stride; ++ IT* output_head = (IT*)output + bsz_idx * output_bsz_stride + ++ head_idx * output_head_stride; ++ ++ int start_row = ++ std::min(tid * cur_sub_rows + std::min(tid, cur_rem_rows), ++ cur_row_block_num); ++ int end_row = ++ std::min(start_row + cur_sub_rows + (tid < cur_rem_rows), ++ cur_row_block_num); ++ ++ float max_attn = -sycl::detail::max_v(); ++ float softmax = 0; ++ simd accs = 0; ++ for 
(int r = start_row; r < end_row; ++r) { ++ float sub_attn = o_a_s_head[2 * r]; ++ float sub_softmax = o_a_s_head[2 * r + 1]; ++ simd sub_accs = block_load(o_accs_head + r * HD); ++ float new_max_attn = std::max(max_attn, sub_attn); ++ float exp1 = exp(max_attn - new_max_attn); ++ float exp2 = exp(sub_attn - new_max_attn); ++ accs = accs * exp1 + sub_accs * exp2; ++ softmax = softmax * exp1 + sub_softmax * exp2; ++ max_attn = new_max_attn; ++ } ++ ++ slm_block_store(accs_slm_offset + tid * HD * sizeof(float), ++ accs); ++ slm_block_store(attn_slm_offset + tid * sizeof(float), ++ max_attn); ++ slm_block_store(softmax_slm_offset + tid * sizeof(float), ++ softmax); ++ barrier(); ++ ++ if (tid < HD / SUB_HD) { ++ simd max_attns = ++ slm_block_load(attn_slm_offset); ++ const simd scales = ++ exp(max_attns - hmax(max_attns)); ++ simd softmaxs = ++ slm_block_load(softmax_slm_offset); ++ float softmax_sum = ++ sycl::ext::intel::esimd::detail::sum( ++ softmaxs * scales); ++ ++ simd result = 0; ++#pragma unroll ++ for (int r = 0; r < GS; ++r) { ++ simd sub_accs = slm_block_load( ++ accs_slm_offset + (r * HD + tid * SUB_HD) * sizeof(float)); ++ result = result + sub_accs * scales[r]; ++ } ++ result = result / softmax_sum; ++ block_store(output_head + tid * SUB_HD, result); ++ } ++ }); ++ }; ++ ++ utils::submit_kernel(cgf, device, "gqa kernel 2/2"); ++} ++ ++template ++auto dispatch_gqa_kernel_fp8(AT it) { ++ switch (it) { ++ case AT::Float: ++ return std::make_tuple(gqa_1_kernel_fp8, ++ gqa_2_kernel_fp8); ++ case AT::Half: ++ return std::make_tuple(gqa_1_kernel_fp8, ++ gqa_2_kernel_fp8); ++ default: ++ throw std::runtime_error( ++ "unsupported dtype, only fp32 and fp16 are supported"); ++ } ++} ++ ++void paged_attention_gqa_fp8(torch::Tensor output, torch::Tensor query, ++ torch::Tensor key_cache, torch::Tensor value_cache, ++ int64_t bsz, int64_t num_heads, int64_t num_kv_heads, ++ float scale, torch::Tensor& block_tables, ++ torch::Tensor& context_lens, int block_size, ++ int64_t head_dim, int max_seq_len) { ++ constexpr int VS = 32; ++ constexpr int GS = 32; ++ ++ const int row_block_num = (max_seq_len + VS - 1) / VS; ++ auto o_a_s = ++ torch::empty({bsz, num_heads, 1, row_block_num * 2}, ++ torch::device(query.device()).dtype(torch::kFloat32)); ++ auto o_accs = ++ torch::empty({bsz, num_heads, 1, row_block_num * head_dim}, ++ torch::device(query.device()).dtype(query.dtype())); ++ ++ auto [func1, func2] = [&]() { ++ switch (head_dim) { ++ case 128: ++ return dispatch_gqa_kernel_fp8(query.scalar_type()); ++ case 96: ++ return dispatch_gqa_kernel_fp8(query.scalar_type()); ++ case 80: ++ return dispatch_gqa_kernel_fp8(query.scalar_type()); ++ case 64: ++ return dispatch_gqa_kernel_fp8(query.scalar_type()); ++ default: ++ throw std::runtime_error( ++ "unsupported head_dim, only 128, 96, 80 and 64 are supported"); ++ } ++ }(); ++ ++ func1(query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), ++ block_tables.data_ptr(), context_lens.data_ptr(), o_a_s.data_ptr(), ++ o_accs.data_ptr(), query.stride(0), query.stride(1), ++ key_cache.stride(0), key_cache.stride(1), key_cache.stride(2), ++ block_tables.stride(0), o_a_s.stride(0), o_a_s.stride(1), ++ o_accs.stride(0), o_accs.stride(1), scale, block_size, bsz, num_heads, ++ num_kv_heads, row_block_num, query.device()); ++ ++ func2(o_a_s.data_ptr(), o_accs.data_ptr(), output.data_ptr(), ++ context_lens.data_ptr(), o_a_s.stride(0), o_a_s.stride(1), ++ o_accs.stride(0), o_accs.stride(1), output.stride(0), output.stride(1), ++ bsz, num_heads, 
row_block_num, query.device()); ++} +diff --git a/csrc/xpu/cache_ops_xpu.cpp b/csrc/xpu/cache_ops_xpu.cpp +new file mode 100644 +index 000000000..a3451c0e7 +--- /dev/null ++++ b/csrc/xpu/cache_ops_xpu.cpp +@@ -0,0 +1,579 @@ ++// clang-format off ++#ifdef VLLM_DEV ++#undef __SYCL_DEVICE_ONLY__ ++#endif ++#include ++#include ++#include +// clang-format on +#include "xpu_types.h" + +#include +#include "utils.h" + ++using fp16 = sycl::half; ++using namespace sycl::ext::intel::esimd; ++ +template +void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] @@ -7128,6 +7464,182 @@ index 000000000..381795cda + key, value, key_cache, value_cache, slot_mapping); + }); +} +diff --git a/csrc/xpu/cache_ops_xpu_fp8.cpp b/csrc/xpu/cache_ops_xpu_fp8.cpp +new file mode 100644 +index 000000000..cbfb7eea1 +--- /dev/null ++++ b/csrc/xpu/cache_ops_xpu_fp8.cpp +@@ -0,0 +1,170 @@ ++// clang-format off ++#ifdef VLLM_DEV ++#undef __SYCL_DEVICE_ONLY__ ++#endif ++#include ++#include ++#include ++// clang-format on ++#include "xpu_types.h" ++ ++#include ++#include "utils.h" ++#include "kv.h" ++ ++using fp16 = sycl::half; ++using namespace sycl::ext::intel::esimd; ++ ++// scalar_t is key.scalar_type() -> half ++template ++void reshape_and_cache_ipexllm_kernel_fp8( ++ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] ++ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] ++ uint8_t * __restrict__ key_cache, // [num_blocks, num_kv_heads, block_size, ++ // head_size] ++ uint8_t * __restrict__ value_cache, // [num_blocks, num_kv_heads, ++ // block_size, head_size] ++ const int64_t* __restrict__ slot_mapping, // [num_tokens] ++ const int key_stride, const int value_stride, ++ const int key_head_stride, const int value_head_stride, ++ const int num_heads, ++ const int head_size, const int block_size, const int x, ++ const sycl::nd_item<3>& item_ct1) { ++ ++ // New Implementation // ++ const size_t token_idx = item_ct1.get_global_id(0); ++ const size_t head_idx = item_ct1.get_global_id(1); ++ const int64_t slot_idx = slot_mapping[token_idx]; ++ if (slot_idx < 0) { ++ return; ++ } ++ const int64_t block_idx = slot_idx / block_size; ++ const int64_t block_offset = slot_idx % block_size; ++ // The thread is responsible for the HD elements within key/value ++ const scalar_t * key_head = key + token_idx * key_stride + head_idx * key_head_stride; ++ ++ const scalar_t * value_head = value + token_idx * value_stride + head_idx * value_head_stride; ++ ++ uint8_t * key_output_head = key_cache + block_idx * num_heads * head_size * block_size + ++ head_idx * head_size * block_size + block_offset * head_size; ++ uint8_t * value_output_head = value_cache + block_idx * num_heads * head_size * block_size + ++ head_idx * head_size * block_size + block_offset * head_size; ++ ++ simd key_row = block_load(key_head); ++ simd key_result = quantize_key_row(key_row); ++ block_store(key_output_head, key_result); ++ ++ simd value_row = block_load(value_head); ++ simd value_result = quantize_value_row(value_row); ++ block_store(value_output_head, value_result); ++} ++ ++ ++template ++void call_reshape_and_cache_ipexllm_kernel_fp8( ++ const scalar_t* __restrict__ key, const scalar_t* __restrict__ value, ++ uint8_t* __restrict__ key_cache, uint8_t* __restrict__ value_cache, ++ const int64_t* __restrict__ slot_mapping, const int num_tokens, ++ const int key_stride, const int value_stride, ++ const int key_head_stride, const int value_head_stride, ++ const int 
num_heads, ++ const int head_size, const int block_size, const int x) { ++ using sycl_t = vllm::xpu::SyclTypeTrait::Type; ++ sycl::range<3> grid(num_tokens, num_heads, 1); ++ sycl::range<3> block(1, 1, 1); ++ auto& queue = vllm::xpu::vllmGetQueue(); ++ queue.submit([&](sycl::handler& cgh) { ++ cgh.parallel_for( ++ sycl::nd_range<3>(grid * block, block), [=](sycl::nd_item<3> item_ct1) SYCL_ESIMD_KERNEL { ++ reshape_and_cache_ipexllm_kernel_fp8( ++ (const sycl_t* __restrict__)key, ++ (const sycl_t* __restrict__)value, ++ (uint8_t* __restrict__)key_cache, ++ (uint8_t* __restrict__)value_cache, slot_mapping, key_stride, ++ value_stride, key_head_stride, value_head_stride, ++ num_heads, head_size, block_size, x, item_ct1); ++ }); ++ }); ++} ++ ++void reshape_and_cache_ipexllm_fp8(torch::Tensor& key, torch::Tensor& value, ++ torch::Tensor& key_cache, ++ torch::Tensor& value_cache, ++ torch::Tensor& slot_mapping, ++ const std::string& kv_cache_dtype, ++ const float kv_scale) { ++ int num_tokens = key.size(0); ++ int num_heads = key.size(1); ++ int head_size = key.size(2); ++ int block_size = key_cache.size(2); ++ // int x = key_cache.size(4); ++ int x = 1; ++ ++ int key_stride = key.stride(0); ++ int value_stride = value.stride(0); ++ ++ int key_head_stride = key.stride(1); ++ int value_head_stride = value.stride(1); ++ ++ // This actually dispatches on scalar_type, we will then need to dispatch on Head Dim... ++switch (head_size) { ++ case 64: ++ VLLM_XPU_DISPATCH_FLOATING_TYPES( ++ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { ++ call_reshape_and_cache_ipexllm_kernel_fp8( ++ key.data_ptr(), value.data_ptr(), ++ key_cache.data_ptr(), value_cache.data_ptr(), ++ slot_mapping.data_ptr(), num_tokens, key_stride, ++ value_stride, key_head_stride, value_head_stride, num_heads, ++ head_size, block_size, x); ++ }); ++ break; ++ case 128: ++ VLLM_XPU_DISPATCH_FLOATING_TYPES( ++ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { ++ call_reshape_and_cache_ipexllm_kernel_fp8( ++ key.data_ptr(), value.data_ptr(), ++ key_cache.data_ptr(), value_cache.data_ptr(), ++ slot_mapping.data_ptr(), num_tokens, key_stride, ++ value_stride, key_head_stride, value_head_stride, num_heads, ++ head_size, block_size, x); ++ }); ++ break; ++ case 96: ++ VLLM_XPU_DISPATCH_FLOATING_TYPES( ++ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { ++ call_reshape_and_cache_ipexllm_kernel_fp8( ++ key.data_ptr(), value.data_ptr(), ++ key_cache.data_ptr(), value_cache.data_ptr(), ++ slot_mapping.data_ptr(), num_tokens, key_stride, ++ value_stride, key_head_stride, value_head_stride, num_heads, ++ head_size, block_size, x); ++ }); ++ break; ++ case 80: ++ VLLM_XPU_DISPATCH_FLOATING_TYPES( ++ key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { ++ call_reshape_and_cache_ipexllm_kernel_fp8( ++ key.data_ptr(), value.data_ptr(), ++ key_cache.data_ptr(), value_cache.data_ptr(), ++ slot_mapping.data_ptr(), num_tokens, key_stride, ++ value_stride, key_head_stride, value_head_stride, num_heads, ++ head_size, block_size, x); ++ }); ++ break; ++ default: ++ TORCH_CHECK(false, "Unsupported head_dim: ", head_size); ++} ++ // VLLM_XPU_DISPATCH_FLOATING_TYPES( ++ // key.scalar_type(), "call_reshape_and_cache_ipexllm_kernel_fp8", [&] { ++ // call_reshape_and_cache_ipexllm_kernel_fp8( ++ // key.data_ptr(), value.data_ptr(), ++ // key_cache.data_ptr(), value_cache.data_ptr(), ++ // slot_mapping.data_ptr(), num_tokens, key_stride, ++ // value_stride, key_head_stride, 
value_head_stride, ++ // num_heads, head_size, block_size, x); ++ // }); ++} ++ ++ ++ diff --git a/csrc/xpu/dequantize.h b/csrc/xpu/dequantize.h new file mode 100644 index 000000000..9a967312e @@ -8081,6 +8593,89 @@ index 000000000..d96aa5880 + return _de_kernel; +} \ No newline at end of file +diff --git a/csrc/xpu/kv.h b/csrc/xpu/kv.h +new file mode 100644 +index 000000000..9616ad7ef +--- /dev/null ++++ b/csrc/xpu/kv.h +@@ -0,0 +1,76 @@ ++#pragma once ++ ++#include ++#include ++ ++using fp16 = sycl::half; ++ ++constexpr uint8_t FP16_EXP_OFFSET = 15; ++constexpr uint8_t K_EXP_OFFSET = 9; ++constexpr uint8_t V_EXP_OFFSET = 12; ++constexpr uint8_t K_OFFSET = (FP16_EXP_OFFSET - K_EXP_OFFSET) << 3; ++constexpr uint8_t V_OFFSET = (FP16_EXP_OFFSET - V_EXP_OFFSET) << 3; ++constexpr uint16_t K_MAX = ++ (uint16_t)0x3FC0 + ((uint16_t)(FP16_EXP_OFFSET - K_EXP_OFFSET) << 10); ++constexpr uint16_t K_MIN = ++ (uint16_t)0x0040 + ((uint16_t)(FP16_EXP_OFFSET - K_EXP_OFFSET) << 10); ++constexpr uint16_t V_MAX = ++ (uint16_t)0x3FC0 + ((uint16_t)(FP16_EXP_OFFSET - V_EXP_OFFSET) << 10); ++constexpr uint16_t V_MIN = ++ (uint16_t)0x0040 + ((uint16_t)(FP16_EXP_OFFSET - V_EXP_OFFSET) << 10); ++ ++template ++ESIMD_INLINE __ESIMD_NS::simd quantize_key_row( ++ __ESIMD_NS::simd key_row) { ++ const __ESIMD_NS::simd kmax = sycl::bit_cast(K_MAX); ++ const __ESIMD_NS::simd kmin = sycl::bit_cast(K_MIN); ++ __ESIMD_NS::simd key = ++ __ESIMD_NS::max(__ESIMD_NS::min(__ESIMD_NS::abs(key_row), kmax), kmin); ++ key.template bit_cast_view() <<= 1; ++ __ESIMD_NS::simd sign = ++ key_row.template bit_cast_view().template select(1) & ++ (uint8_t)0x80; ++ return (key.template bit_cast_view().template select(1) - ++ K_OFFSET) | ++ sign; ++} ++ ++template ++ESIMD_INLINE __ESIMD_NS::simd quantize_value_row( ++ __ESIMD_NS::simd value_row) { ++ const __ESIMD_NS::simd vmax = sycl::bit_cast(V_MAX); ++ const __ESIMD_NS::simd vmin = sycl::bit_cast(V_MIN); ++ __ESIMD_NS::simd value = ++ __ESIMD_NS::max(__ESIMD_NS::min(__ESIMD_NS::abs(value_row), vmax), vmin); ++ value.template bit_cast_view() <<= 1; ++ __ESIMD_NS::simd sign = ++ value_row.template bit_cast_view().template select(1) & ++ (uint8_t)0x80; ++ return (value.template bit_cast_view().template select(1) - ++ V_OFFSET) | ++ sign; ++} ++ ++template ++ESIMD_INLINE __ESIMD_NS::simd dequantize_key_row( ++ const __ESIMD_NS::simd& key_row) { ++ __ESIMD_NS::simd result = 0x80; ++ result.template bit_cast_view().template select(1) = ++ (key_row & (uint8_t)0x7F) + K_OFFSET; ++ result >>= 1; ++ __ESIMD_NS::simd sign = key_row & (uint8_t)0x80; ++ result.template bit_cast_view().template select(1) |= sign; ++ return result.template bit_cast_view(); ++} ++ ++template ++ESIMD_INLINE __ESIMD_NS::simd dequantize_value_row( ++ const __ESIMD_NS::simd& value_row) { ++ __ESIMD_NS::simd result = 0x80; ++ result.template bit_cast_view().template select(1) = ++ (value_row & (uint8_t)0x7F) + V_OFFSET; ++ result >>= 1; ++ __ESIMD_NS::simd sign = value_row & (uint8_t)0x80; ++ result.template bit_cast_view().template select(1) |= sign; ++ return result.template bit_cast_view(); ++} +\ No newline at end of file diff --git a/csrc/xpu/layernorm_xpu.cpp b/csrc/xpu/layernorm_xpu.cpp new file mode 100644 index 000000000..9a6a2af0a @@ -8617,10 +9212,10 @@ index 000000000..3232cacbc \ No newline at end of file diff --git a/csrc/xpu/pybind.cpp b/csrc/xpu/pybind.cpp new file mode 100644 -index 000000000..224419995 +index 000000000..55b29cb1e --- /dev/null +++ b/csrc/xpu/pybind.cpp -@@ -0,0 +1,94 @@ +@@ -0,0 +1,101 @@ +// 
#include "cache.h" +#include "xpu_ops.h" +#include @@ -8649,7 +9244,9 @@ index 000000000..224419995 + "paged_attention_gqa", + &paged_attention_gqa, + "PagedAttention GQA."); -+ ++ ++ ops.def("paged_attention_gqa_fp8", &paged_attention_gqa_fp8, "PagedAttention GQA fp8."); ++ + // Activation ops + ops.def( + "silu_and_mul", @@ -8708,6 +9305,11 @@ index 000000000..224419995 + &reshape_and_cache_ipexllm, + "Reshape the key and value tensors and cache them for ipex_llm"); + ++ cache_ops.def( ++ "reshape_and_cache_ipexllm_fp8", ++ &reshape_and_cache_ipexllm_fp8, ++ "Reshape the key and value tensors and cache them for ipex_llm with fp8"); ++ + // Quant + ops.def( + "awq_dequantize", @@ -8908,10 +9510,10 @@ index 000000000..fa3ead51c +} diff --git a/csrc/xpu/xpu_ops.h b/csrc/xpu/xpu_ops.h new file mode 100644 -index 000000000..1f71b84b2 +index 000000000..e78cc59a1 --- /dev/null +++ b/csrc/xpu/xpu_ops.h -@@ -0,0 +1,160 @@ +@@ -0,0 +1,174 @@ +#include + +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, @@ -8984,6 +9586,13 @@ index 000000000..1f71b84b2 + torch::Tensor &slot_mapping, + const std::string& kv_cache_dtype, const float kv_scale); + ++void reshape_and_cache_ipexllm_fp8(torch::Tensor& key, torch::Tensor& value, ++ torch::Tensor& key_cache, ++ torch::Tensor& value_cache, ++ torch::Tensor& slot_mapping, ++ const std::string& kv_cache_dtype, ++ const float kv_scale); ++ +void moe_align_block_size( + torch::Tensor topk_ids, + int num_experts, @@ -9072,6 +9681,14 @@ index 000000000..1f71b84b2 + int64_t head_dim, + int max_seq_len +); ++ ++void paged_attention_gqa_fp8(torch::Tensor output, torch::Tensor query, ++ torch::Tensor key_cache, torch::Tensor value_cache, ++ int64_t bsz, int64_t num_heads, int64_t num_kv_heads, ++ float scale, torch::Tensor& block_tables, ++ torch::Tensor& context_lens, int block_size, ++ int64_t head_dim, int max_seq_len); +\ No newline at end of file diff --git a/csrc/xpu/xpu_types.h b/csrc/xpu/xpu_types.h new file mode 100644 index 000000000..23f5b805c @@ -17095,7 +17712,7 @@ index aeacf5dda..eb2f69df4 100644 def register_fake(fn): return lambda name: fn diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py -index 28b804f76..6e8b5dbe2 100644 +index 28b804f76..f73ba0d3b 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,4 +1,4 @@ @@ -17104,9 +17721,18 @@ index 28b804f76..6e8b5dbe2 100644 import torch -@@ -11,6 +11,7 @@ try: - except ImportError as e: - logger.warning("Import error msg: %s", e.msg) +@@ -6,11 +6,12 @@ from vllm.logger import init_logger + + logger = init_logger(__name__) + +-try: +- import intel_extension_for_pytorch as ipex +-except ImportError as e: +- logger.warning("Import error msg: %s", e.msg) ++# try: ++# import intel_extension_for_pytorch as ipex ++# except ImportError as e: ++# logger.warning("Import error msg: %s", e.msg) +import vllm._C.ops @@ -17330,30 +17956,15 @@ index 28b804f76..6e8b5dbe2 100644 @staticmethod def varlen_attention( -@@ -185,15 +229,13 @@ class ipex_ops: +@@ -185,6 +229,7 @@ class ipex_ops: gen_: torch.Generator, logits_soft_cap: float, ) -> None: -- ipex.llm.functional.varlen_attention(query.contiguous(), -- key.contiguous(), -- value.contiguous(), out, -- seqlen_q.int(), seqlen_k.int(), -- max_seqlen_q, max_seqlen_k, -- pdropout, softmax_scale, -- zero_tensors, is_causal, -- return_softmax, gen_, -- logits_soft_cap) -+ pass -+ -+ # ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q, -+ # seqlen_k, max_seqlen_q, -+ # max_seqlen_k, pdropout, -+ # softmax_scale, zero_tensors, 
-+ # is_causal, return_softmax, gen_) - - @staticmethod - def reshape_and_cache( -@@ -205,22 +247,180 @@ class ipex_ops: ++ import intel_extension_for_pytorch as ipex + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, +@@ -205,22 +250,233 @@ class ipex_ops: kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -17374,7 +17985,60 @@ index 28b804f76..6e8b5dbe2 100644 + k_scale: float, + v_scale: float, + ) -> None: -+ vllm._C.cache_ops.reshape_and_cache_ipexllm(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale) ++ if kv_cache_dtype == "fp8": ++ vllm._C.cache_ops.reshape_and_cache_ipexllm_fp8(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale) ++ else: ++ vllm._C.cache_ops.reshape_and_cache_ipexllm(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale) ++ ++ @staticmethod ++ def paged_attention_gqa( ++ out: torch.Tensor, ++ query: torch.Tensor, ++ key_cache: torch.Tensor, ++ value_cache: torch.Tensor, ++ batch_size: int, ++ num_heads: int, ++ num_kv_heads: int, ++ scale: float, ++ block_tables: torch.Tensor, ++ seq_lens_tensor: torch.Tensor, ++ block_size: int, ++ head_size: int, ++ max_seq_len: int, ++ kv_cache_format: str ++ ): ++ if kv_cache_format == "fp8": ++ vllm._C.ops.paged_attention_gqa_fp8( ++ out, ++ query, ++ key_cache, ++ value_cache, ++ batch_size, ++ num_heads, ++ num_kv_heads, ++ scale, ++ block_tables, ++ seq_lens_tensor, ++ block_size, ++ head_size, ++ max_seq_len ++ ) ++ else: ++ vllm._C.ops.paged_attention_gqa( ++ out, ++ query, ++ key_cache, ++ value_cache, ++ batch_size, ++ num_heads, ++ num_kv_heads, ++ scale, ++ block_tables, ++ seq_lens_tensor, ++ block_size, ++ head_size, ++ max_seq_len ++ ) + + @staticmethod + def reshape_and_cache_flash( @@ -17572,7 +18236,7 @@ index cb831cb0b..0a55506f8 100644 - return torch.load(image_path, map_location="cpu") + return torch.load(image_path, map_location="cpu", weights_only=True) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py -index 21949874b..79ed61f35 100644 +index 21949874b..7902d02e3 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -4,7 +4,7 @@ from dataclasses import dataclass @@ -17584,7 +18248,20 @@ index 21949874b..79ed61f35 100644 from vllm._ipex_ops import ipex_ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -@@ -49,18 +49,16 @@ class IpexAttnBackend(AttentionBackend): +@@ -12,7 +12,12 @@ from vllm.attention.backends.utils import CommonAttentionState + from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + ++from vllm.logger import init_logger ++logger = init_logger('vllm.attention.backends.ipex_attn') ++from vllm.utils import print_info_once, print_warning_once ++ + _PARTITION_SIZE = 512 ++_IPEX_BACKEND_SUPPORTED_KV_CACHE_FORMAT=["fp8", "auto"] + + + class IpexAttnBackend(AttentionBackend): +@@ -49,18 +54,16 @@ class IpexAttnBackend(AttentionBackend): dst_kv_cache: torch.Tensor, src_to_dst: torch.Tensor, ) -> None: @@ -17605,7 +18282,7 @@ index 21949874b..79ed61f35 100644 @dataclass -@@ -74,6 +72,11 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): +@@ -74,6 +77,11 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] seqlen_q: Optional[torch.Tensor] max_seqlen: Optional[int] @@ -17617,7 +18294,7 @@ index 21949874b..79ed61f35 100644 def 
__post_init__(self): # Set during the execution of the first attention op. -@@ -86,21 +89,140 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): +@@ -86,21 +94,140 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): @property def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: # Currently chunked prefill is not supported @@ -17767,7 +18444,7 @@ index 21949874b..79ed61f35 100644 class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): -@@ -134,7 +256,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -134,7 +261,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) if logits_soft_cap is None: @@ -17776,7 +18453,32 @@ index 21949874b..79ed61f35 100644 self.logits_soft_cap = logits_soft_cap supported_head_sizes = PagedAttention.get_supported_head_sizes() -@@ -153,16 +275,34 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -142,10 +269,20 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") +- if kv_cache_dtype != "auto": +- raise NotImplementedError( +- "IPEX backend does not support FP8 KV cache. " +- "Please use xFormers backend instead.") ++ if kv_cache_dtype not in _IPEX_BACKEND_SUPPORTED_KV_CACHE_FORMAT: ++ raise NotImplementedError(f"IPEX backend does not support " ++ "KV cache format {kv_cache_dtype}") ++ # Also check for gqa models... ++ self.using_gqa_kernel = use_gqa_kernel(self.num_heads, self.num_kv_heads, self.head_size, self.logits_soft_cap) ++ if not self.using_gqa_kernel and kv_cache_dtype == "fp8": ++ raise NotImplementedError(f"IPEX backend currently only supports " ++ "fp8 kv cache in group-query attention") ++ ++ self.ipex_varlen_attn = False ++ flag = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None) ++ if flag is not None: ++ self.ipex_varlen_attn = True ++ print_info_once(f"Using varlen_attention for prefilling.") + + def split_kv_cache( + self, +@@ -153,16 +290,34 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): num_kv_heads: int, head_size: int, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -17812,7 +18514,7 @@ index 21949874b..79ed61f35 100644 def forward( self, query: torch.Tensor, -@@ -200,75 +340,166 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -200,75 +355,172 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) @@ -17834,9 +18536,8 @@ index 21949874b..79ed61f35 100644 - assert attn_metadata.seq_lens is not None - if (kv_cache.numel() == 0 - or attn_metadata.block_tables.numel() == 0): -+ using_gqa_kernel = use_gqa_kernel(self.num_heads, self.num_kv_heads, self.head_size, self.logits_soft_cap) + if kv_cache is not None: -+ if using_gqa_kernel: ++ if self.using_gqa_kernel: + key_cache, value_cache = self.split_kv_cache_ipexllm( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache_ipexllm( @@ -17907,29 +18608,8 @@ index 21949874b..79ed61f35 100644 - att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, None, dtype=query.dtype) - attn_metadata.attn_bias = att_masks -+ att_masks = [None] * len(prefill_meta.seq_lens) -+ prefill_meta.attn_bias = att_masks -+ -+ # output = torch.empty( -+ # (num_tokens, self.num_heads, self.head_size), -+ # dtype=query.dtype, 
-+ # device=query.device) -+ # ipex_ops.varlen_attention(query, -+ # key, -+ # value, -+ # output, -+ # attn_metadata.seqlen_q, -+ # attn_metadata.seqlen_q, -+ # attn_metadata.max_seqlen, -+ # attn_metadata.max_seqlen, -+ # pdropout=0.0, -+ # softmax_scale=self.scale, -+ # zero_tensors=False, -+ # is_causal=True, -+ # return_softmax=False, -+ # gen_=None) - - output = torch.empty( +- +- output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype, - device=query.device) @@ -17950,50 +18630,80 @@ index 21949874b..79ed61f35 100644 - gen_=None, - logits_soft_cap=self.logits_soft_cap, - ) -+ (num_tokens, self.num_heads, self.head_size), -+ dtype=query.dtype, device=query.device) -+ query = query.movedim(0, query.dim() - 2) -+ key = key.movedim(0, key.dim() - 2) -+ value = value.movedim(0, value.dim() - 2) -+ import math -+ scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale -+ start = 0 -+ for seq_len, mask in zip(prefill_meta.seq_lens, -+ prefill_meta.attn_bias): -+ end = start + seq_len -+ if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap): -+ import xe_addons -+ if mask is not None: -+ mask = mask.unsqueeze(0) -+ if self.logits_soft_cap == 0 or self.head_size != 256: -+ sub_out = xe_addons.sdp_causal( -+ query[None, :, start:end, :].contiguous(), -+ key[None, :, start:end, :].contiguous(), -+ value[None, :, start:end, :].contiguous(), -+ mask, -+ scale).squeeze(0).movedim( -+ query.dim() - 2, 0) ++ att_masks = [None] * len(prefill_meta.seq_lens) ++ prefill_meta.attn_bias = att_masks ++ ++ if self.ipex_varlen_attn: ++ output = torch.empty( ++ (num_tokens, self.num_heads, self.head_size), ++ dtype=query.dtype, ++ device=query.device) ++ ++ tmp = [0] ++ tmp.extend(prefill_meta.seq_lens) ++ seqlen = torch.tensor(tmp) ++ seqlen_q = torch.cumsum(seqlen, dim=0).to(device=query.device) ++ ipex_ops.varlen_attention(query, ++ key, ++ value, ++ output, ++ seqlen_q, ++ seqlen_q, ++ prefill_meta.max_seqlen, ++ prefill_meta.max_seqlen, ++ pdropout=0.0, ++ softmax_scale=self.scale, ++ zero_tensors=False, ++ is_causal=True, ++ return_softmax=False, ++ gen_=None, ++ logits_soft_cap=self.logits_soft_cap) ++ else: ++ output = torch.empty( ++ (num_tokens, self.num_heads, self.head_size), ++ dtype=query.dtype, device=query.device) ++ query = query.movedim(0, query.dim() - 2) ++ key = key.movedim(0, key.dim() - 2) ++ value = value.movedim(0, value.dim() - 2) ++ import math ++ scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale ++ start = 0 ++ for seq_len, mask in zip(prefill_meta.seq_lens, ++ prefill_meta.attn_bias): ++ end = start + seq_len ++ if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap): ++ import xe_addons ++ if mask is not None: ++ mask = mask.unsqueeze(0) ++ if self.logits_soft_cap == 0 or self.head_size != 256: ++ sub_out = xe_addons.sdp_causal( ++ query[None, :, start:end, :].contiguous(), ++ key[None, :, start:end, :].contiguous(), ++ value[None, :, start:end, :].contiguous(), ++ mask, ++ scale).squeeze(0).movedim( ++ query.dim() - 2, 0) ++ else: ++ sub_out = xe_addons.gemma2_sdp_causal( ++ query[None, :, start:end, :].contiguous(), ++ key[None, :, start:end, :].contiguous(), ++ value[None, :, start:end, :].contiguous(), ++ mask, ++ self.logits_soft_cap, ++ self.scale).squeeze(0).movedim( ++ query.dim() - 2, 0) + else: -+ sub_out = xe_addons.gemma2_sdp_causal( -+ query[None, :, start:end, :].contiguous(), -+ key[None, :, start:end, 
:].contiguous(), -+ value[None, :, start:end, :].contiguous(), -+ mask, -+ self.logits_soft_cap, -+ self.scale).squeeze(0).movedim( -+ query.dim() - 2, 0) -+ else: -+ sub_out = torch.nn.functional.scaled_dot_product_attention( -+ query[None, :, start:end, :], -+ key[None, :, start:end, :], -+ value[None, :, start:end, :], -+ attn_mask=mask, -+ dropout_p=0.0, -+ is_causal=not self.need_mask, -+ scale=self.scale).squeeze(0).movedim( -+ query.dim() - 2, 0) -+ output[start:end, :, :] = sub_out -+ start = end ++ sub_out = torch.nn.functional.scaled_dot_product_attention( ++ query[None, :, start:end, :], ++ key[None, :, start:end, :], ++ value[None, :, start:end, :], ++ attn_mask=mask, ++ dropout_p=0.0, ++ is_causal=not self.need_mask, ++ scale=self.scale).squeeze(0).movedim( ++ query.dim() - 2, 0) ++ output[start:end, :, :] = sub_out ++ start = end else: # prefix-enabled attention - raise RuntimeError( @@ -18007,7 +18717,7 @@ index 21949874b..79ed61f35 100644 + import vllm._C.ops + assert self.head_size == 128 or self.head_size == 64 + value = os.environ.get('USE_CONTEXT_V1') -+ if using_gqa_kernel: ++ if self.using_gqa_kernel: + # if using_gqa_kernel, then only the v1 kernel can be used + out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) + elif value is None: @@ -18031,7 +18741,7 @@ index 21949874b..79ed61f35 100644 max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use -@@ -279,59 +510,85 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -279,59 +531,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory # shortage. @@ -18046,9 +18756,9 @@ index 21949874b..79ed61f35 100644 + bsz = len(decode_meta.seq_lens) + import vllm._C.ops + -+ if using_gqa_kernel: ++ if self.using_gqa_kernel: + block_size = value_cache.shape[2] -+ vllm._C.ops.paged_attention_gqa( ++ ipex_ops.paged_attention_gqa( + out, + decode_query, key_cache, @@ -18062,13 +18772,13 @@ index 21949874b..79ed61f35 100644 + decode_meta.block_tables, + decode_meta.seq_lens_tensor, block_size, -- max_seq_len, ++ head_size, + max_seq_len, - self.alibi_slopes, - self.kv_cache_dtype, - k_scale, - v_scale, -+ head_size, -+ max_seq_len ++ self.kv_cache_dtype ) else: - # Run PagedAttention V2. 
@@ -18192,10 +18902,24 @@ index 350f88c8f..17ebe6ddf 100644 if IS_COMPUTE_8_OR_ABOVE: from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py -index cbc6c74ac..c95815505 100644 +index cbc6c74ac..6b7afcc66 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py -@@ -9,6 +9,7 @@ except ImportError: +@@ -1,14 +1,16 @@ + from typing import Dict, List, Optional, Tuple + +-try: +- import intel_extension_for_pytorch.llm.modules as ipex_modules +- _use_ipex = True +-except ImportError: +- _use_ipex = False ++# try: ++# import intel_extension_for_pytorch.llm.modules as ipex_modules ++# _use_ipex = True ++# except ImportError: ++# _use_ipex = False ++_use_ipex = False + import torch from vllm import _custom_ops as ops @@ -18203,7 +18927,7 @@ index cbc6c74ac..c95815505 100644 class _PagedAttention: -@@ -187,5 +188,44 @@ class _IPEXPagedAttention(_PagedAttention): +@@ -187,5 +189,44 @@ class _IPEXPagedAttention(_PagedAttention): scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes)