From 5df03ced2ca7aaa1844ea575ae17735332a72956 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Mon, 12 May 2025 10:54:22 +0800
Subject: [PATCH] Update vllm patch to fix telechat2 and baichuan2 errors (#13150)

---
 .../xpu/docker/vllm_for_multi_arc.patch | 550 +++++++++++++-----
 1 file changed, 409 insertions(+), 141 deletions(-)

diff --git a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
index f2fc7a07..aa898fc4 100644
--- a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
+++ b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch
@@ -7078,61 +7078,61 @@ index 000000000..93c64d759
 --- /dev/null
 +++ b/csrc/xpu/reduction_utils.h
 @@ -0,0 +1,56 @@
-+/*
-+ * Copyright (c) 2023, The vLLM team.
-+ *
-+ * Licensed under the Apache License, Version 2.0 (the "License");
-+ * you may not use this file except in compliance with the License.
-+ * You may obtain a copy of the License at
-+ *
-+ *     http://www.apache.org/licenses/LICENSE-2.0
-+ *
-+ * Unless required by applicable law or agreed to in writing, software
-+ * distributed under the License is distributed on an "AS IS" BASIS,
-+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-+ * See the License for the specific language governing permissions and
-+ * limitations under the License.
-+ */
-+#pragma once
-+
-+#include
-+#include
-+#include
-+
-+namespace vllm {
-+
-+template
-+__inline__ T warpReduceSum(T val, const sycl::nd_item<3>& item_ct1) {
-+#pragma unroll
-+  for (int mask = 16; mask > 0; mask >>= 1)
-+    val += dpct::permute_sub_group_by_xor(
-+        item_ct1.get_sub_group(), val, mask, 32);
-+  return val;
-+}
-+
-+/* Calculate the sum of all elements in a block */
-+template
-+__inline__ T blockReduceSum(T val, const sycl::nd_item<3> &item_ct1, T *shared) {
-+
-+  int lane = item_ct1.get_local_id(2) & 0x1f;
-+  int wid = item_ct1.get_local_id(2) >> 5;
-+
-+  val = warpReduceSum(val, item_ct1);
-+
-+  if (lane == 0) {
-+    shared[wid] = val;
-+  }
-+  item_ct1.barrier();
-+
-+  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-+  // blockDim.x is not divided by 32
-+  val = (item_ct1.get_local_id(2) < (item_ct1.get_local_range(2) / 32.f))
-+            ? shared[lane]
-+            : (T)(0.0f);
-+  val = warpReduceSum(val, item_ct1);
-+  return val;
-+}
-+
++/*
++ * Copyright (c) 2023, The vLLM team.
++ *
++ * Licensed under the Apache License, Version 2.0 (the "License");
++ * you may not use this file except in compliance with the License.
++ * You may obtain a copy of the License at
++ *
++ *     http://www.apache.org/licenses/LICENSE-2.0
++ *
++ * Unless required by applicable law or agreed to in writing, software
++ * distributed under the License is distributed on an "AS IS" BASIS,
++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++ * See the License for the specific language governing permissions and
++ * limitations under the License.
++ */ ++#pragma once ++ ++#include ++#include ++#include ++ ++namespace vllm { ++ ++template ++__inline__ T warpReduceSum(T val, const sycl::nd_item<3>& item_ct1) { ++#pragma unroll ++ for (int mask = 16; mask > 0; mask >>= 1) ++ val += dpct::permute_sub_group_by_xor( ++ item_ct1.get_sub_group(), val, mask, 32); ++ return val; ++} ++ ++/* Calculate the sum of all elements in a block */ ++template ++__inline__ T blockReduceSum(T val, const sycl::nd_item<3> &item_ct1, T *shared) { ++ ++ int lane = item_ct1.get_local_id(2) & 0x1f; ++ int wid = item_ct1.get_local_id(2) >> 5; ++ ++ val = warpReduceSum(val, item_ct1); ++ ++ if (lane == 0) { ++ shared[wid] = val; ++ } ++ item_ct1.barrier(); ++ ++ // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent ++ // blockDim.x is not divided by 32 ++ val = (item_ct1.get_local_id(2) < (item_ct1.get_local_range(2) / 32.f)) ++ ? shared[lane] ++ : (T)(0.0f); ++ val = warpReduceSum(val, item_ct1); ++ return val; ++} ++ +} // namespace vllm \ No newline at end of file diff --git a/csrc/xpu/utils.cpp b/csrc/xpu/utils.cpp @@ -8692,7 +8692,7 @@ index 000000000..e98db9b65 + tensor_parallel_size=1, + ) diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py -index c3d210c27..c3b6ca7eb 100644 +index c3d210c27..8dd101608 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,6 +1,4 @@ @@ -8780,10 +8780,10 @@ index c3d210c27..c3b6ca7eb 100644 + # todo: ipex will refactor namespace + import vllm._C.ops + vllm._C.ops.paged_attention_v1(out, query, -+ key_cache.view_as(value_cache), -+ value_cache, num_kv_heads, scale, -+ block_tables, context_lens, block_size, -+ max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap) ++ key_cache.view_as(value_cache), ++ value_cache, num_kv_heads, scale, ++ block_tables, context_lens, block_size, ++ max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap) @staticmethod def paged_attention_v2( @@ -8929,7 +8929,7 @@ index c3d210c27..c3b6ca7eb 100644 @staticmethod def varlen_attention( -@@ -220,22 +262,233 @@ class ipex_ops: +@@ -220,22 +262,250 @@ class ipex_ops: kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -9044,30 +9044,47 @@ index c3d210c27..c3b6ca7eb 100644 + p_dropout: float, + softmax_scale: float, + zero_tensors: bool, -+ is_caual: bool, ++ is_casual: bool, + return_softmax: bool, + gen_: Optional[torch.Generator], + ): -+ return torch.ops.torch_ipex.chunked_prefill( ++ return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( ++ output, + query.contiguous(), + key_cache, + value_cache, -+ output, + cu_seqlens_q, + cu_seqlens_k, -+ seq_used_k, -+ block_table, -+ alibi_slopes, + max_seqlen_q, + max_seqlen_k, -+ p_dropout, + softmax_scale, -+ zero_tensors, -+ is_caual, -+ return_softmax, -+ gen_, ++ is_casual, ++ block_table, ++ alibi_slopes, ++ k_scale=1.0, ++ v_scale=1.0, ) - ++ # return torch.ops.torch_ipex.chunked_prefill( ++ # query.contiguous(), ++ # key_cache, ++ # value_cache, ++ # output, ++ # cu_seqlens_q, ++ # cu_seqlens_k, ++ # seq_used_k, ++ # block_table, ++ # alibi_slopes, ++ # max_seqlen_q, ++ # max_seqlen_k, ++ # p_dropout, ++ # softmax_scale, ++ # zero_tensors, ++ # is_caual, ++ # return_softmax, ++ # gen_, ++ # ) ++ ++ + @staticmethod + def copy_blocks(key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], @@ -9078,7 +9095,7 @@ index c3d210c27..c3b6ca7eb 100644 + # block_mapping, + # ) + vllm._C.cache_ops.copy_blocks(key_caches, value_caches, block_mapping) -+ + @staticmethod def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: 
torch.Tensor) -> None: @@ -11666,6 +11683,143 @@ index 5649cf2dd..66e30984e 100644 if isinstance(load_config.load_format, type): return load_config.load_format(load_config) +diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py +index 6a3112b5f..7e2b7c862 100644 +--- a/vllm/model_executor/models/baichuan.py ++++ b/vllm/model_executor/models/baichuan.py +@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata + from vllm.sequence import IntermediateTensors + + from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +-from .utils import (AutoWeightsLoader, is_pp_missing_parameter, ++from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +@@ -321,45 +321,6 @@ class BaiChuanModel(nn.Module): + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + +- def load_weights(self, weights: Iterable[Tuple[str, +- torch.Tensor]]) -> Set[str]: +- stacked_params_mapping = [ +- # (param_name, shard_name, shard_id) +- ("gate_up_proj", "gate_proj", 0), +- ("gate_up_proj", "up_proj", 1), +- ] +- params_dict = dict(self.named_parameters()) +- loaded_params: Set[str] = set() +- for name, loaded_weight in weights: +- if "rotary_emb.inv_freq" in name: +- continue +- +- for (param_name, weight_name, shard_id) in stacked_params_mapping: +- if weight_name not in name: +- continue +- name = name.replace(weight_name, param_name) +- # Skip loading extra bias for GPTQ models. +- if name.endswith(".bias") and name not in params_dict: +- continue +- if is_pp_missing_parameter(name, self): +- continue +- param = params_dict[name] +- weight_loader = param.weight_loader +- weight_loader(param, loaded_weight, shard_id) +- break +- else: +- # Skip loading extra bias for GPTQ models. +- if name.endswith(".bias") and name not in params_dict: +- continue +- if is_pp_missing_parameter(name, self): +- continue +- param = params_dict[name] +- weight_loader = getattr(param, "weight_loader", +- default_weight_loader) +- weight_loader(param, loaded_weight) +- loaded_params.add(name) +- return loaded_params +- + + class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + SupportsQuant): +@@ -392,7 +353,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) +- self.lm_head.weight.weight_loader = self.lm_head_weight_loader + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) +@@ -433,22 +393,53 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: +- loader = AutoWeightsLoader(self) +- return loader.load_weights(weights) +- +- def lm_head_weight_loader(self, param: nn.Parameter, +- loaded_weight: torch.Tensor): +- # Unlike Baichuan, Baichuan2 normalizes the head weights. +- # Refer to: +- # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 +- # Distinguish between Baichuan and Baichuan2 by checking the +- # vocab size. 
This is suggested by +- # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 +- is_baichuan2 = self.config.vocab_size == 125696 +- if is_baichuan2: +- loaded_weight = torch.nn.functional.normalize(loaded_weight) +- +- default_weight_loader(param, loaded_weight) ++ stacked_params_mapping = [ ++ # (param_name, shard_name, shard_id) ++ ("gate_up_proj", "gate_proj", 0), ++ ("gate_up_proj", "up_proj", 1), ++ ] ++ params_dict = dict(self.named_parameters()) ++ loaded_params: Set[str] = set() ++ for name, loaded_weight in weights: ++ if "rotary_emb.inv_freq" in name: ++ continue ++ if name == "lm_head.weight": ++ # Unlike Baichuan, Baichuan2 normalizes the head weights. ++ # Refer to: ++ # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 ++ # Distinguish between Baichuan and Baichuan2 by checking the ++ # vocab size. This is suggested by ++ # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 ++ is_baichuan2 = self.config.vocab_size == 125696 ++ if is_baichuan2: ++ loaded_weight = torch.nn.functional.normalize( ++ loaded_weight) ++ ++ for (param_name, weight_name, shard_id) in stacked_params_mapping: ++ if weight_name not in name: ++ continue ++ name = name.replace(weight_name, param_name) ++ # Skip loading extra bias for GPTQ models. ++ if name.endswith(".bias") and name not in params_dict: ++ continue ++ if is_pp_missing_parameter(name, self): ++ continue ++ param = params_dict[name] ++ weight_loader = param.weight_loader ++ weight_loader(param, loaded_weight, shard_id) ++ break ++ else: ++ # Skip loading extra bias for GPTQ models. ++ if name.endswith(".bias") and name not in params_dict: ++ continue ++ if is_pp_missing_parameter(name, self): ++ continue ++ param = params_dict[name] ++ weight_loader = getattr(param, "weight_loader", ++ default_weight_loader) ++ weight_loader(param, loaded_weight) ++ loaded_params.add(name) ++ return loaded_params + + + class BaichuanForCausalLM(BaiChuanBaseForCausalLM): diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1b1738f88..2c2ed67b9 100644 --- a/vllm/model_executor/models/chatglm.py @@ -14147,7 +14301,7 @@ index c0a3c59ba..8614c2273 100644 "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), # [Encoder-decoder] diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py -index cecad9e89..df4cf4776 100644 +index cecad9e89..7eaabd1db 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -140,6 +140,74 @@ class SiglipVisionEmbeddings(nn.Module): @@ -14195,9 +14349,9 @@ index cecad9e89..df4cf4776 100644 + + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) -+ from ipex_llm.transformers.models.utils import use_sdp_causal + from vllm.attention.backends.ipex_attn import use_sdp_causal + import xe_addons, math ++ from vllm.attention.backends.abstract import AttentionType + mask = None + scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale + from ipex_llm.transformers.models.common import padding_qkv_hd @@ -14209,7 +14363,7 @@ index cecad9e89..df4cf4776 100644 + query, key, value, + self.head_size, num + ) -+ if use_sdp_causal(query.shape[-1], query, 0): ++ if use_sdp_causal(query.shape[-1], query, 0, AttentionType.DECODER): + out = xe_addons.sdp_non_causal(query.contiguous(), key.contiguous(), value.contiguous(), mask, scale)[:, :, :, :self.head_size].transpose(1, 2) + # import 
torch.nn.functional as F + # out = F.scaled_dot_product_attention(query, @@ -14239,10 +14393,23 @@ index cecad9e89..df4cf4776 100644 def forward( self, diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py -index a38035e37..9631fbd83 100644 +index a38035e37..570f2bcdd 100644 --- a/vllm/model_executor/models/telechat2.py +++ b/vllm/model_executor/models/telechat2.py -@@ -44,9 +44,9 @@ class TeleChat2Model(LlamaModel): +@@ -22,10 +22,12 @@ + from typing import Iterable, Set, Tuple + + import torch ++import torch.nn as nn + + from vllm.config import VllmConfig + from vllm.model_executor.model_loader.weight_utils import default_weight_loader + from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel ++from .llama import LlamaDecoderLayer + + from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter) +@@ -44,9 +46,9 @@ class TeleChat2Model(LlamaModel): for layer in self.layers: if not isinstance(layer, PPMissingLayer): layer.self_attn.qkv_proj.bias = None @@ -14254,6 +14421,18 @@ index a38035e37..9631fbd83 100644 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: +@@ -120,7 +122,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM): + }, + ) + +- def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): ++ def _init_model(self, ++ vllm_config: VllmConfig, ++ prefix: str = "", ++ layer_type: type[nn.Module] = LlamaDecoderLayer): + return TeleChat2Model(vllm_config=vllm_config, prefix=prefix) + + def load_weights(self, weights: Iterable[Tuple[str, diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index fc0fb8929..6454e7006 100644 --- a/vllm/multimodal/utils.py @@ -14319,7 +14498,7 @@ index b6f6029de..b90fea9fd 100644 def is_neuron(self) -> bool: return self._enum == PlatformEnum.NEURON diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py -index 225e756cd..4fd7fe220 100644 +index 225e756cd..25b83549a 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional @@ -14330,7 +14509,17 @@ index 225e756cd..4fd7fe220 100644 from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum, _Backend -@@ -33,8 +34,13 @@ class XPUPlatform(Platform): +@@ -25,6 +26,9 @@ class XPUPlatform(Platform): + # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501 + ray_device_key: str = "GPU" + device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR" ++ additional_env_vars: list[str] = [ ++ "IPEX_LLM_LOWBIT", ++ ] + + @classmethod + def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, +@@ -33,8 +37,13 @@ class XPUPlatform(Platform): use_mla: bool) -> str: if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) @@ -14346,7 +14535,7 @@ index 225e756cd..4fd7fe220 100644 @staticmethod def get_device_capability( -@@ -63,6 +69,8 @@ class XPUPlatform(Platform): +@@ -63,6 +72,8 @@ class XPUPlatform(Platform): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config @@ -14355,7 +14544,7 @@ index 225e756cd..4fd7fe220 100644 if cache_config and cache_config.block_size is None: cache_config.block_size = 16 -@@ -87,31 +95,46 @@ class XPUPlatform(Platform): +@@ -87,31 +98,46 @@ class XPUPlatform(Platform): raise NotImplementedError( "XPU does not support speculative 
decoding") @@ -14412,6 +14601,15 @@ index 225e756cd..4fd7fe220 100644 @classmethod def is_pin_memory_available(cls): +@@ -140,3 +166,7 @@ class XPUPlatform(Platform): + @classmethod + def get_device_communicator_cls(cls) -> str: + return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa ++ ++ @classmethod ++ def use_all_gather(cls) -> bool: ++ return False +\ No newline at end of file diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 53699341b..6bc039068 100644 --- a/vllm/transformers_utils/configs/__init__.py @@ -14432,6 +14630,35 @@ index 53699341b..6bc039068 100644 "ChatGLMConfig", "Cohere2Config", "DbrxConfig", +diff --git a/vllm/utils.py b/vllm/utils.py +index 5f32f8cb6..2ee0c1906 100644 +--- a/vllm/utils.py ++++ b/vllm/utils.py +@@ -128,6 +128,8 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = { + "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, + } + ++BMG_TARGET_IDS = ["0xe20b", "0xe210"] ++ + # Constants related to forcing the attention backend selection + + # String name of register which may be set in order to +@@ -2564,3 +2566,14 @@ def sha256(input) -> int: + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return int.from_bytes(hashlib.sha256(input_bytes).digest(), + byteorder="big") ++ ++@cache ++def is_bmg_platform(): ++ if not torch.xpu.is_available(): ++ raise ValueError("Cannot detect the usage of XPU!") ++ device_index = torch.xpu.current_device() ++ device_name = torch.xpu.get_device_name(device_index) ++ for target_id in BMG_TARGET_IDS: ++ if target_id in device_name: ++ return True ++ return False +\ No newline at end of file diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index c271f438e..cf7180606 100755 --- a/vllm/v1/attention/backends/flash_attn.py @@ -14457,10 +14684,10 @@ index c271f438e..cf7180606 100755 assert sliding_window == (-1, -1), ( diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py new file mode 100644 -index 000000000..29cde02f3 +index 000000000..f4a435eaa --- /dev/null +++ b/vllm/v1/attention/backends/ipex_attn.py -@@ -0,0 +1,358 @@ +@@ -0,0 +1,392 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + @@ -14474,6 +14701,7 @@ index 000000000..29cde02f3 +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.attention.backends.ipex_attn import use_gqa_kernel ++from vllm.utils import is_bmg_platform +import os + +@dataclass @@ -14509,9 +14737,9 @@ index 000000000..29cde02f3 + # if block_size % 16 != 0: + # raise ValueError("Block size must be a multiple of 16.") + # This needs to be changed... 
-+ # return (2, num_blocks, block_size, num_kv_heads, head_size) -+ return PagedAttention.get_kv_cache_shape(num_blocks, block_size, -+ num_kv_heads, head_size) ++ return (2, num_blocks, block_size, num_kv_heads, head_size) ++ # return PagedAttention.get_kv_cache_shape(num_blocks, block_size, ++ # num_kv_heads, head_size) + + + @@ -14557,6 +14785,8 @@ index 000000000..29cde02f3 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes() ++ self.using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap) ++ self.is_bmg_platform = is_bmg_platform() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " @@ -14567,7 +14797,6 @@ index 000000000..29cde02f3 + "are not implemented for " + "IpexAttnBackendImpl") + -+ # TODO(gc): Refine this logic..., because of bad performance... + def forward( + self, + layer: AttentionLayer, @@ -14610,6 +14839,8 @@ index 000000000..29cde02f3 + k_scale, + v_scale, + self.scale, ++ self.using_gqa_kernel, ++ self.is_bmg_platform, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, @@ -14682,6 +14913,8 @@ index 000000000..29cde02f3 + k_scale: float, + v_scale: float, + scale: float, ++ using_gqa_kernel: bool, ++ is_bmg_platform: bool, + sliding_window: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, @@ -14700,54 +14933,82 @@ index 000000000..29cde02f3 + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + -+ using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap) -+ -+ -+ if using_gqa_kernel: -+ key_cache, value_cache = split_kv_cache_ipexllm( ++ if is_bmg_platform: ++ key_cache, value_cache = kv_cache.unbind(0) ++ ipex_ops.reshape_and_cache_flash( ++ key[:num_actual_tokens], ++ value[:num_actual_tokens], ++ key_cache, ++ value_cache, ++ attn_metadata.slot_mapping, ++ kv_cache_dtype, ++ k_scale, ++ v_scale, ++ ) ++ ipex_ops.chunked_prefill( ++ query[:num_actual_tokens].contiguous(), ++ key_cache, ++ value_cache, ++ output[:num_actual_tokens], ++ attn_metadata.query_start_loc, ++ attn_metadata.seq_start_loc, ++ None, ++ attn_metadata.block_table, ++ alibi_slopes, ++ attn_metadata.max_query_len, ++ attn_metadata.max_seq_len, ++ 0.0, ++ scale, ++ False, ++ True, ++ False, ++ None, ++ ) ++ else: ++ if using_gqa_kernel: ++ key_cache, value_cache = split_kv_cache_ipexllm( ++ kv_cache, num_kv_heads, head_size) ++ ipex_ops.reshape_and_cache_ipexllm( ++ key[:num_actual_tokens], ++ value[:num_actual_tokens], ++ key_cache, ++ value_cache, ++ attn_metadata.slot_mapping.flatten(), ++ kv_cache_dtype, ++ k_scale, ++ v_scale, ++ ) ++ else: ++ key_cache, value_cache = split_kv_cache( + kv_cache, num_kv_heads, head_size) -+ ipex_ops.reshape_and_cache_ipexllm( -+ key[:num_actual_tokens], -+ value[:num_actual_tokens], -+ key_cache, -+ value_cache, -+ attn_metadata.slot_mapping.flatten(), -+ kv_cache_dtype, -+ k_scale, -+ v_scale, -+ ) -+ else: -+ key_cache, value_cache = split_kv_cache( -+ kv_cache, num_kv_heads, head_size) -+ ipex_ops.reshape_and_cache( -+ key[:num_actual_tokens], -+ value[:num_actual_tokens], -+ key_cache, -+ value_cache, -+ attn_metadata.slot_mapping.flatten(), -+ kv_cache_dtype, -+ k_scale, -+ v_scale, -+ ) -+ # Invoke chunked prefill method... 
-+ import vllm._C.ops -+ assert head_size == 128 or head_size == 64 -+ value = os.environ.get('USE_CONTEXT_V1') -+ query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1] -+ seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1] -+ context_len = seq_len - query_len -+ if using_gqa_kernel: -+ # if using_gqa_kernel, then only the v1 kernel can be used -+ out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) -+ elif value is None: -+ # Otherwise, by default use v2 attention forward kernel... -+ out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item()) -+ else: -+ out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) -+ -+ # output[:num_actual_tokens] = out -+ output[:num_actual_tokens] = out.view(out.shape[0], -1) ++ ipex_ops.reshape_and_cache( ++ key[:num_actual_tokens], ++ value[:num_actual_tokens], ++ key_cache, ++ value_cache, ++ attn_metadata.slot_mapping.flatten(), ++ kv_cache_dtype, ++ k_scale, ++ v_scale, ++ ) ++ # Invoke chunked prefill method... ++ import vllm._C.ops ++ assert head_size == 128 or head_size == 64 ++ value = os.environ.get('USE_CONTEXT_V1') ++ query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1] ++ seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1] ++ context_len = seq_len - query_len ++ if using_gqa_kernel: ++ # if using_gqa_kernel, then only the v1 kernel can be used ++ out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) ++ elif value is None: ++ # Otherwise, by default use v2 attention forward kernel... 
++ out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item()) ++ else: ++ out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item()) + ++ # output[:num_actual_tokens] = out ++ output[:num_actual_tokens] = out.view(out.shape[0], -1) + + + @@ -15648,10 +15909,10 @@ index 000000000..8612d3d77 + self.kv_caches) diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py new file mode 100644 -index 000000000..1bc531e28 +index 000000000..1fb0dca87 --- /dev/null +++ b/vllm/v1/worker/xpu_worker.py -@@ -0,0 +1,168 @@ +@@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Optional @@ -15685,9 +15946,16 @@ index 000000000..1bc531e28 + assert device_config.device_type == "xpu" + assert current_platform.is_xpu() + -+ def load_model(self) -> None: -+ self.model_runner.load_model() ++ import os ++ lowbit = os.getenv("IPEX_LLM_LOWBIT", None) ++ if lowbit is not None: ++ from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert ++ _ipex_llm_convert(lowbit) + ++ ++ def compile_or_warm_up_model(self) -> None: ++ pass ++ + # we provide this function due to `torch.xpu.mem_get_info()` doesn't + # return correct free_gpu_memory on intel client GPU. We need to + # calculate/estiamte it. @@ -15838,7 +16106,7 @@ index 86e6d9752..ad80bf54e 100644 @dataclass(frozen=True) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py -index 9d49b4385..67f07f5b1 100644 +index 9d49b4385..7396b0c89 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -5,8 +5,8 @@ import time @@ -16163,7 +16431,7 @@ index 9d49b4385..67f07f5b1 100644 + slot_mapping_tensor = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) -+ if need_block_table: ++ if need_block_table or "bge" in self.runner.model_config.model.lower(): + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device)