Update vllm patch to fix telechat2 and baichuan2 errors (#13150)
This commit is contained in:
parent 9da1c56fa8
commit 5df03ced2c

1 changed file with 409 additions and 141 deletions
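The baichuan2 side of this update replaces the AutoWeightsLoader call in BaiChuanBaseForCausalLM.load_weights with an explicit loading loop and folds the Baichuan2-specific lm_head normalization into it, in place of the removed lm_head_weight_loader hook. A minimal sketch of just that normalization rule, in plain PyTorch and outside vLLM's sharded loaders (the helper name is illustrative, not from the patch):

import torch
import torch.nn.functional as F

def maybe_normalize_lm_head(name: str, loaded_weight: torch.Tensor,
                            vocab_size: int) -> torch.Tensor:
    # Illustrative helper: Baichuan2 checkpoints are identified by their
    # 125696-token vocabulary, and their lm_head weights are L2-normalized
    # before loading (matching upstream modeling_baichuan.py); Baichuan1
    # weights pass through unchanged.
    if name == "lm_head.weight" and vocab_size == 125696:
        loaded_weight = F.normalize(loaded_weight)
    return loaded_weight

In the patch this check runs inline at the top of load_weights, before the stacked-parameter (gate_up_proj) mapping is applied.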
				
			
		| 
						 | 
@@ -7078,61 +7078,61 @@ index 000000000..93c64d759
 | 
				
			||||||
--- /dev/null
 | 
					--- /dev/null
 | 
				
			||||||
+++ b/csrc/xpu/reduction_utils.h
 | 
					+++ b/csrc/xpu/reduction_utils.h
 | 
				
			||||||
@@ -0,0 +1,56 @@
 | 
					@@ -0,0 +1,56 @@
 | 
				
			||||||
+/*
 | 
					+/*
 | 
				
			||||||
+ * Copyright (c) 2023, The vLLM team.
 | 
					+ * Copyright (c) 2023, The vLLM team.
 | 
				
			||||||
+ *
 | 
					+ *
 | 
				
			||||||
+ * Licensed under the Apache License, Version 2.0 (the "License");
 | 
					+ * Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
+ * you may not use this file except in compliance with the License.
 | 
					+ * you may not use this file except in compliance with the License.
 | 
				
			||||||
+ * You may obtain a copy of the License at
 | 
					+ * You may obtain a copy of the License at
 | 
				
			||||||
+ *
 | 
					+ *
 | 
				
			||||||
+ *     http://www.apache.org/licenses/LICENSE-2.0
 | 
					+ *     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
+ *
 | 
					+ *
 | 
				
			||||||
+ * Unless required by applicable law or agreed to in writing, software
 | 
					+ * Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
+ * distributed under the License is distributed on an "AS IS" BASIS,
 | 
					+ * distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
					+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
+ * See the License for the specific language governing permissions and
 | 
					+ * See the License for the specific language governing permissions and
 | 
				
			||||||
+ * limitations under the License.
 | 
					+ * limitations under the License.
 | 
				
			||||||
+ */
 | 
					+ */
 | 
				
			||||||
+#pragma once
 | 
					+#pragma once
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+#include <dpct/dpct.hpp>
 | 
					+#include <dpct/dpct.hpp>
 | 
				
			||||||
+#include <stdint.h>
 | 
					+#include <stdint.h>
 | 
				
			||||||
+#include <sycl/sycl.hpp>
 | 
					+#include <sycl/sycl.hpp>
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+namespace vllm {
 | 
					+namespace vllm {
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+template <typename T>
 | 
					+template <typename T>
 | 
				
			||||||
+__inline__ T warpReduceSum(T val, const sycl::nd_item<3>& item_ct1) {
 | 
					+__inline__ T warpReduceSum(T val, const sycl::nd_item<3>& item_ct1) {
 | 
				
			||||||
+#pragma unroll
 | 
					+#pragma unroll
 | 
				
			||||||
+  for (int mask = 16; mask > 0; mask >>= 1)
 | 
					+  for (int mask = 16; mask > 0; mask >>= 1)
 | 
				
			||||||
+    val += dpct::permute_sub_group_by_xor(
 | 
					+    val += dpct::permute_sub_group_by_xor(
 | 
				
			||||||
+        item_ct1.get_sub_group(), val, mask, 32);
 | 
					+        item_ct1.get_sub_group(), val, mask, 32);
 | 
				
			||||||
+  return val;
 | 
					+  return val;
 | 
				
			||||||
+}
 | 
					+}
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+/* Calculate the sum of all elements in a block */
 | 
					+/* Calculate the sum of all elements in a block */
 | 
				
			||||||
+template<typename T>
 | 
					+template<typename T>
 | 
				
			||||||
+__inline__ T blockReduceSum(T val, const sycl::nd_item<3> &item_ct1, T *shared) {
 | 
					+__inline__ T blockReduceSum(T val, const sycl::nd_item<3> &item_ct1, T *shared) {
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+  int lane = item_ct1.get_local_id(2) & 0x1f;
 | 
					+  int lane = item_ct1.get_local_id(2) & 0x1f;
 | 
				
			||||||
+  int wid = item_ct1.get_local_id(2) >> 5;
 | 
					+  int wid = item_ct1.get_local_id(2) >> 5;
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+  val = warpReduceSum<T>(val, item_ct1);
 | 
					+  val = warpReduceSum<T>(val, item_ct1);
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+  if (lane == 0) {
 | 
					+  if (lane == 0) {
 | 
				
			||||||
+    shared[wid] = val;
 | 
					+    shared[wid] = val;
 | 
				
			||||||
+  }
 | 
					+  }
 | 
				
			||||||
+  item_ct1.barrier();
 | 
					+  item_ct1.barrier();
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
 | 
					+  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
 | 
				
			||||||
+  // blockDim.x is not divided by 32
 | 
					+  // blockDim.x is not divided by 32
 | 
				
			||||||
+  val = (item_ct1.get_local_id(2) < (item_ct1.get_local_range(2) / 32.f))
 | 
					+  val = (item_ct1.get_local_id(2) < (item_ct1.get_local_range(2) / 32.f))
 | 
				
			||||||
+            ? shared[lane]
 | 
					+            ? shared[lane]
 | 
				
			||||||
+            : (T)(0.0f);
 | 
					+            : (T)(0.0f);
 | 
				
			||||||
+  val = warpReduceSum<T>(val, item_ct1);
 | 
					+  val = warpReduceSum<T>(val, item_ct1);
 | 
				
			||||||
+  return val;
 | 
					+  return val;
 | 
				
			||||||
+}
 | 
					+}
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+} // namespace vllm
 | 
					+} // namespace vllm
 | 
				
			||||||
\ No newline at end of file
 | 
					\ No newline at end of file
 | 
				
			||||||
diff --git a/csrc/xpu/utils.cpp b/csrc/xpu/utils.cpp
 | 
					diff --git a/csrc/xpu/utils.cpp b/csrc/xpu/utils.cpp
 | 
				
			||||||
| 
						 | 
					@ -8692,7 +8692,7 @@ index 000000000..e98db9b65
 | 
				
			||||||
+        tensor_parallel_size=1,
 | 
					+        tensor_parallel_size=1,
 | 
				
			||||||
+    )
 | 
					+    )
 | 
				
			||||||
diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
 | 
					diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py
 | 
				
			||||||
index c3d210c27..c3b6ca7eb 100644
 | 
					index c3d210c27..8dd101608 100644
 | 
				
			||||||
--- a/vllm/_ipex_ops.py
 | 
					--- a/vllm/_ipex_ops.py
 | 
				
			||||||
+++ b/vllm/_ipex_ops.py
 | 
					+++ b/vllm/_ipex_ops.py
 | 
				
			||||||
@@ -1,6 +1,4 @@
 | 
					@@ -1,6 +1,4 @@
 | 
				
			||||||
| 
						 | 
					@ -8780,10 +8780,10 @@ index c3d210c27..c3b6ca7eb 100644
 | 
				
			||||||
+        # todo: ipex will refactor namespace
 | 
					+        # todo: ipex will refactor namespace
 | 
				
			||||||
+        import vllm._C.ops
 | 
					+        import vllm._C.ops
 | 
				
			||||||
+        vllm._C.ops.paged_attention_v1(out, query,
 | 
					+        vllm._C.ops.paged_attention_v1(out, query,
 | 
				
			||||||
+                                     key_cache.view_as(value_cache),
 | 
					+                                       key_cache.view_as(value_cache),
 | 
				
			||||||
+                                     value_cache, num_kv_heads, scale,
 | 
					+                                       value_cache, num_kv_heads, scale,
 | 
				
			||||||
+                                     block_tables, context_lens, block_size,
 | 
					+                                       block_tables, context_lens, block_size,
 | 
				
			||||||
+                                     max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap)
 | 
					+                                       max_context_len, alibi_slopes, kv_cache_dtype, k_scale, logits_soft_cap)
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
     @staticmethod
 | 
					     @staticmethod
 | 
				
			||||||
     def paged_attention_v2(
 | 
					     def paged_attention_v2(
 | 
				
			||||||
| 
						 | 
					@ -8929,7 +8929,7 @@ index c3d210c27..c3b6ca7eb 100644
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
     @staticmethod
 | 
					     @staticmethod
 | 
				
			||||||
     def varlen_attention(
 | 
					     def varlen_attention(
 | 
				
			||||||
@@ -220,22 +262,233 @@ class ipex_ops:
 | 
					@@ -220,22 +262,250 @@ class ipex_ops:
 | 
				
			||||||
         kv_cache_dtype: str,
 | 
					         kv_cache_dtype: str,
 | 
				
			||||||
         k_scale: float,
 | 
					         k_scale: float,
 | 
				
			||||||
         v_scale: float,
 | 
					         v_scale: float,
 | 
				
			||||||
| 
						 | 
					@ -9044,30 +9044,47 @@ index c3d210c27..c3b6ca7eb 100644
 | 
				
			||||||
+        p_dropout: float,
 | 
					+        p_dropout: float,
 | 
				
			||||||
+        softmax_scale: float,
 | 
					+        softmax_scale: float,
 | 
				
			||||||
+        zero_tensors: bool,
 | 
					+        zero_tensors: bool,
 | 
				
			||||||
+        is_caual: bool,
 | 
					+        is_casual: bool,
 | 
				
			||||||
+        return_softmax: bool,
 | 
					+        return_softmax: bool,
 | 
				
			||||||
+        gen_: Optional[torch.Generator],
 | 
					+        gen_: Optional[torch.Generator],
 | 
				
			||||||
+    ):
 | 
					+    ):
 | 
				
			||||||
+        return torch.ops.torch_ipex.chunked_prefill(
 | 
					+        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
 | 
				
			||||||
 | 
					+            output,
 | 
				
			||||||
+            query.contiguous(),
 | 
					+            query.contiguous(),
 | 
				
			||||||
+            key_cache,
 | 
					+            key_cache,
 | 
				
			||||||
+            value_cache,
 | 
					+            value_cache,
 | 
				
			||||||
+            output,
 | 
					 | 
				
			||||||
+            cu_seqlens_q,
 | 
					+            cu_seqlens_q,
 | 
				
			||||||
+            cu_seqlens_k,
 | 
					+            cu_seqlens_k,
 | 
				
			||||||
+            seq_used_k,
 | 
					 | 
				
			||||||
+            block_table,
 | 
					 | 
				
			||||||
+            alibi_slopes,
 | 
					 | 
				
			||||||
+            max_seqlen_q,
 | 
					+            max_seqlen_q,
 | 
				
			||||||
+            max_seqlen_k,
 | 
					+            max_seqlen_k,
 | 
				
			||||||
+            p_dropout,
 | 
					 | 
				
			||||||
+            softmax_scale,
 | 
					+            softmax_scale,
 | 
				
			||||||
+            zero_tensors,
 | 
					+            is_casual,
 | 
				
			||||||
+            is_caual,
 | 
					+            block_table,
 | 
				
			||||||
+            return_softmax,
 | 
					+            alibi_slopes,
 | 
				
			||||||
+            gen_,
 | 
					+            k_scale=1.0,
 | 
				
			||||||
 | 
					+            v_scale=1.0,
 | 
				
			||||||
         )
 | 
					         )
 | 
				
			||||||
 
 | 
					+        # return torch.ops.torch_ipex.chunked_prefill(
 | 
				
			||||||
 | 
					+        #     query.contiguous(),
 | 
				
			||||||
 | 
					+        #     key_cache,
 | 
				
			||||||
 | 
					+        #     value_cache,
 | 
				
			||||||
 | 
					+        #     output,
 | 
				
			||||||
 | 
					+        #     cu_seqlens_q,
 | 
				
			||||||
 | 
					+        #     cu_seqlens_k,
 | 
				
			||||||
 | 
					+        #     seq_used_k,
 | 
				
			||||||
 | 
					+        #     block_table,
 | 
				
			||||||
 | 
					+        #     alibi_slopes,
 | 
				
			||||||
 | 
					+        #     max_seqlen_q,
 | 
				
			||||||
 | 
					+        #     max_seqlen_k,
 | 
				
			||||||
 | 
					+        #     p_dropout,
 | 
				
			||||||
 | 
					+        #     softmax_scale,
 | 
				
			||||||
 | 
					+        #     zero_tensors,
 | 
				
			||||||
 | 
					+        #     is_caual,
 | 
				
			||||||
 | 
					+        #     return_softmax,
 | 
				
			||||||
 | 
					+        #     gen_,
 | 
				
			||||||
 | 
					+        # )
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
+    @staticmethod
 | 
					+    @staticmethod
 | 
				
			||||||
+    def copy_blocks(key_caches: List[torch.Tensor],
 | 
					+    def copy_blocks(key_caches: List[torch.Tensor],
 | 
				
			||||||
+                    value_caches: List[torch.Tensor],
 | 
					+                    value_caches: List[torch.Tensor],
 | 
				
			||||||
| 
						 | 
					@ -9078,7 +9095,7 @@ index c3d210c27..c3b6ca7eb 100644
 | 
				
			||||||
+        #     block_mapping,
 | 
					+        #     block_mapping,
 | 
				
			||||||
+        # )
 | 
					+        # )
 | 
				
			||||||
+        vllm._C.cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
 | 
					+        vllm._C.cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
 | 
				
			||||||
+
 | 
					 
 | 
				
			||||||
     @staticmethod
 | 
					     @staticmethod
 | 
				
			||||||
     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
 | 
					     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
 | 
				
			||||||
                     block_mapping: torch.Tensor) -> None:
 | 
					                     block_mapping: torch.Tensor) -> None:
 | 
				
			||||||
| 
						 | 
					@ -11666,6 +11683,143 @@ index 5649cf2dd..66e30984e 100644
 | 
				
			||||||
     if isinstance(load_config.load_format, type):
 | 
					     if isinstance(load_config.load_format, type):
 | 
				
			||||||
         return load_config.load_format(load_config)
 | 
					         return load_config.load_format(load_config)
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
 | 
					diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
 | 
				
			||||||
 | 
					index 6a3112b5f..7e2b7c862 100644
 | 
				
			||||||
 | 
					--- a/vllm/model_executor/models/baichuan.py
 | 
				
			||||||
 | 
					+++ b/vllm/model_executor/models/baichuan.py
 | 
				
			||||||
 | 
					@@ -47,7 +47,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 | 
				
			||||||
 | 
					 from vllm.sequence import IntermediateTensors
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
 | 
				
			||||||
 | 
					-from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
 | 
				
			||||||
 | 
					+from .utils import (is_pp_missing_parameter,
 | 
				
			||||||
 | 
					                     make_empty_intermediate_tensors_factory, make_layers)
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					@@ -321,45 +321,6 @@ class BaiChuanModel(nn.Module):
 | 
				
			||||||
 | 
					         hidden_states, _ = self.norm(hidden_states, residual)
 | 
				
			||||||
 | 
					         return hidden_states
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-    def load_weights(self, weights: Iterable[Tuple[str,
 | 
				
			||||||
 | 
					-                                                   torch.Tensor]]) -> Set[str]:
 | 
				
			||||||
 | 
					-        stacked_params_mapping = [
 | 
				
			||||||
 | 
					-            # (param_name, shard_name, shard_id)
 | 
				
			||||||
 | 
					-            ("gate_up_proj", "gate_proj", 0),
 | 
				
			||||||
 | 
					-            ("gate_up_proj", "up_proj", 1),
 | 
				
			||||||
 | 
					-        ]
 | 
				
			||||||
 | 
					-        params_dict = dict(self.named_parameters())
 | 
				
			||||||
 | 
					-        loaded_params: Set[str] = set()
 | 
				
			||||||
 | 
					-        for name, loaded_weight in weights:
 | 
				
			||||||
 | 
					-            if "rotary_emb.inv_freq" in name:
 | 
				
			||||||
 | 
					-                continue
 | 
				
			||||||
 | 
					-
 | 
				
			||||||
 | 
					-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
 | 
				
			||||||
 | 
					-                if weight_name not in name:
 | 
				
			||||||
 | 
					-                    continue
 | 
				
			||||||
 | 
					-                name = name.replace(weight_name, param_name)
 | 
				
			||||||
 | 
					-                # Skip loading extra bias for GPTQ models.
 | 
				
			||||||
 | 
					-                if name.endswith(".bias") and name not in params_dict:
 | 
				
			||||||
 | 
					-                    continue
 | 
				
			||||||
 | 
					-                if is_pp_missing_parameter(name, self):
 | 
				
			||||||
 | 
					-                    continue
 | 
				
			||||||
 | 
					-                param = params_dict[name]
 | 
				
			||||||
 | 
					-                weight_loader = param.weight_loader
 | 
				
			||||||
 | 
					-                weight_loader(param, loaded_weight, shard_id)
 | 
				
			||||||
 | 
					-                break
 | 
				
			||||||
 | 
					-            else:
 | 
				
			||||||
 | 
					-                # Skip loading extra bias for GPTQ models.
 | 
				
			||||||
 | 
					-                if name.endswith(".bias") and name not in params_dict:
 | 
				
			||||||
 | 
					-                    continue
 | 
				
			||||||
 | 
					-                if is_pp_missing_parameter(name, self):
 | 
				
			||||||
 | 
					-                    continue
 | 
				
			||||||
 | 
					-                param = params_dict[name]
 | 
				
			||||||
 | 
					-                weight_loader = getattr(param, "weight_loader",
 | 
				
			||||||
 | 
					-                                        default_weight_loader)
 | 
				
			||||||
 | 
					-                weight_loader(param, loaded_weight)
 | 
				
			||||||
 | 
					-            loaded_params.add(name)
 | 
				
			||||||
 | 
					-        return loaded_params
 | 
				
			||||||
 | 
					-
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
 | 
				
			||||||
 | 
					                               SupportsQuant):
 | 
				
			||||||
 | 
					@@ -392,7 +353,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
 | 
				
			||||||
 | 
					         self.lm_head = ParallelLMHead(config.vocab_size,
 | 
				
			||||||
 | 
					                                       config.hidden_size,
 | 
				
			||||||
 | 
					                                       quant_config=quant_config)
 | 
				
			||||||
 | 
					-        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
 | 
				
			||||||
 | 
					         if self.config.tie_word_embeddings:
 | 
				
			||||||
 | 
					             self.lm_head.weight = self.model.embed_tokens.weight
 | 
				
			||||||
 | 
					         self.logits_processor = LogitsProcessor(config.vocab_size)
 | 
				
			||||||
 | 
					@@ -433,22 +393,53 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					     def load_weights(self, weights: Iterable[Tuple[str,
 | 
				
			||||||
 | 
					                                                    torch.Tensor]]) -> Set[str]:
 | 
				
			||||||
 | 
					-        loader = AutoWeightsLoader(self)
 | 
				
			||||||
 | 
					-        return loader.load_weights(weights)
 | 
				
			||||||
 | 
					-
 | 
				
			||||||
 | 
					-    def lm_head_weight_loader(self, param: nn.Parameter,
 | 
				
			||||||
 | 
					-                              loaded_weight: torch.Tensor):
 | 
				
			||||||
 | 
					-        # Unlike Baichuan, Baichuan2 normalizes the head weights.
 | 
				
			||||||
 | 
					-        # Refer to:
 | 
				
			||||||
 | 
					-        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
 | 
				
			||||||
 | 
					-        # Distinguish between Baichuan and Baichuan2 by checking the
 | 
				
			||||||
 | 
					-        # vocab size. This is suggested by
 | 
				
			||||||
 | 
					-        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
 | 
				
			||||||
 | 
					-        is_baichuan2 = self.config.vocab_size == 125696
 | 
				
			||||||
 | 
					-        if is_baichuan2:
 | 
				
			||||||
 | 
					-            loaded_weight = torch.nn.functional.normalize(loaded_weight)
 | 
				
			||||||
 | 
					-
 | 
				
			||||||
 | 
					-        default_weight_loader(param, loaded_weight)
 | 
				
			||||||
 | 
					+        stacked_params_mapping = [
 | 
				
			||||||
 | 
					+            # (param_name, shard_name, shard_id)
 | 
				
			||||||
 | 
					+            ("gate_up_proj", "gate_proj", 0),
 | 
				
			||||||
 | 
					+            ("gate_up_proj", "up_proj", 1),
 | 
				
			||||||
 | 
					+        ]
 | 
				
			||||||
 | 
					+        params_dict = dict(self.named_parameters())
 | 
				
			||||||
 | 
					+        loaded_params: Set[str] = set()
 | 
				
			||||||
 | 
					+        for name, loaded_weight in weights:
 | 
				
			||||||
 | 
					+            if "rotary_emb.inv_freq" in name:
 | 
				
			||||||
 | 
					+                continue
 | 
				
			||||||
 | 
					+            if name == "lm_head.weight":
 | 
				
			||||||
 | 
					+                # Unlike Baichuan, Baichuan2 normalizes the head weights.
 | 
				
			||||||
 | 
					+                # Refer to:
 | 
				
			||||||
 | 
					+                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
 | 
				
			||||||
 | 
					+                # Distinguish between Baichuan and Baichuan2 by checking the
 | 
				
			||||||
 | 
					+                # vocab size. This is suggested by
 | 
				
			||||||
 | 
					+                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
 | 
				
			||||||
 | 
					+                is_baichuan2 = self.config.vocab_size == 125696
 | 
				
			||||||
 | 
					+                if is_baichuan2:
 | 
				
			||||||
 | 
					+                    loaded_weight = torch.nn.functional.normalize(
 | 
				
			||||||
 | 
					+                        loaded_weight)
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
 | 
				
			||||||
 | 
					+                if weight_name not in name:
 | 
				
			||||||
 | 
					+                    continue
 | 
				
			||||||
 | 
					+                name = name.replace(weight_name, param_name)
 | 
				
			||||||
 | 
					+                # Skip loading extra bias for GPTQ models.
 | 
				
			||||||
 | 
					+                if name.endswith(".bias") and name not in params_dict:
 | 
				
			||||||
 | 
					+                    continue
 | 
				
			||||||
 | 
					+                if is_pp_missing_parameter(name, self):
 | 
				
			||||||
 | 
					+                    continue
 | 
				
			||||||
 | 
					+                param = params_dict[name]
 | 
				
			||||||
 | 
					+                weight_loader = param.weight_loader
 | 
				
			||||||
 | 
					+                weight_loader(param, loaded_weight, shard_id)
 | 
				
			||||||
 | 
					+                break
 | 
				
			||||||
 | 
					+            else:
 | 
				
			||||||
 | 
					+                # Skip loading extra bias for GPTQ models.
 | 
				
			||||||
 | 
					+                if name.endswith(".bias") and name not in params_dict:
 | 
				
			||||||
 | 
					+                    continue
 | 
				
			||||||
 | 
					+                if is_pp_missing_parameter(name, self):
 | 
				
			||||||
 | 
					+                    continue
 | 
				
			||||||
 | 
					+                param = params_dict[name]
 | 
				
			||||||
 | 
					+                weight_loader = getattr(param, "weight_loader",
 | 
				
			||||||
 | 
					+                                        default_weight_loader)
 | 
				
			||||||
 | 
					+                weight_loader(param, loaded_weight)
 | 
				
			||||||
 | 
					+            loaded_params.add(name)
 | 
				
			||||||
 | 
					+        return loaded_params
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
 | 
				
			||||||
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
 | 
					diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
 | 
				
			||||||
index 1b1738f88..2c2ed67b9 100644
 | 
					index 1b1738f88..2c2ed67b9 100644
 | 
				
			||||||
--- a/vllm/model_executor/models/chatglm.py
 | 
					--- a/vllm/model_executor/models/chatglm.py
 | 
				
			||||||
| 
						 | 
					@ -14147,7 +14301,7 @@ index c0a3c59ba..8614c2273 100644
 | 
				
			||||||
     "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
 | 
					     "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
 | 
				
			||||||
     # [Encoder-decoder]
 | 
					     # [Encoder-decoder]
 | 
				
			||||||
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
 | 
					diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
 | 
				
			||||||
index cecad9e89..df4cf4776 100644
 | 
					index cecad9e89..7eaabd1db 100644
 | 
				
			||||||
--- a/vllm/model_executor/models/siglip.py
 | 
					--- a/vllm/model_executor/models/siglip.py
 | 
				
			||||||
+++ b/vllm/model_executor/models/siglip.py
 | 
					+++ b/vllm/model_executor/models/siglip.py
 | 
				
			||||||
@@ -140,6 +140,74 @@ class SiglipVisionEmbeddings(nn.Module):
 | 
					@@ -140,6 +140,74 @@ class SiglipVisionEmbeddings(nn.Module):
 | 
				
			||||||
| 
						 | 
					@ -14195,9 +14349,9 @@ index cecad9e89..df4cf4776 100644
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+        query, key, value = (x.transpose(1, 2)
 | 
					+        query, key, value = (x.transpose(1, 2)
 | 
				
			||||||
+                                for x in (query, key, value))
 | 
					+                                for x in (query, key, value))
 | 
				
			||||||
+        from ipex_llm.transformers.models.utils import use_sdp_causal
 | 
					 | 
				
			||||||
+        from vllm.attention.backends.ipex_attn import use_sdp_causal
 | 
					+        from vllm.attention.backends.ipex_attn import use_sdp_causal
 | 
				
			||||||
+        import xe_addons, math
 | 
					+        import xe_addons, math
 | 
				
			||||||
 | 
					+        from vllm.attention.backends.abstract import AttentionType
 | 
				
			||||||
+        mask = None
 | 
					+        mask = None
 | 
				
			||||||
+        scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale
 | 
					+        scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale
 | 
				
			||||||
+        from ipex_llm.transformers.models.common import padding_qkv_hd
 | 
					+        from ipex_llm.transformers.models.common import padding_qkv_hd
 | 
				
			||||||
| 
						 | 
					@ -14209,7 +14363,7 @@ index cecad9e89..df4cf4776 100644
 | 
				
			||||||
+            query, key, value,
 | 
					+            query, key, value,
 | 
				
			||||||
+            self.head_size, num
 | 
					+            self.head_size, num
 | 
				
			||||||
+        )
 | 
					+        )
 | 
				
			||||||
+        if use_sdp_causal(query.shape[-1], query, 0):
 | 
					+        if use_sdp_causal(query.shape[-1], query, 0, AttentionType.DECODER):
 | 
				
			||||||
+            out = xe_addons.sdp_non_causal(query.contiguous(), key.contiguous(), value.contiguous(), mask, scale)[:, :, :, :self.head_size].transpose(1, 2)
 | 
					+            out = xe_addons.sdp_non_causal(query.contiguous(), key.contiguous(), value.contiguous(), mask, scale)[:, :, :, :self.head_size].transpose(1, 2)
 | 
				
			||||||
+        # import torch.nn.functional as F
 | 
					+        # import torch.nn.functional as F
 | 
				
			||||||
+        # out = F.scaled_dot_product_attention(query,
 | 
					+        # out = F.scaled_dot_product_attention(query,
 | 
				
			||||||
| 
						 | 
					@ -14239,10 +14393,23 @@ index cecad9e89..df4cf4776 100644
 | 
				
			||||||
     def forward(
 | 
					     def forward(
 | 
				
			||||||
         self,
 | 
					         self,
 | 
				
			||||||
diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
 | 
					diff --git a/vllm/model_executor/models/telechat2.py b/vllm/model_executor/models/telechat2.py
 | 
				
			||||||
index a38035e37..9631fbd83 100644
 | 
					index a38035e37..570f2bcdd 100644
 | 
				
			||||||
--- a/vllm/model_executor/models/telechat2.py
 | 
					--- a/vllm/model_executor/models/telechat2.py
 | 
				
			||||||
+++ b/vllm/model_executor/models/telechat2.py
 | 
					+++ b/vllm/model_executor/models/telechat2.py
 | 
				
			||||||
@@ -44,9 +44,9 @@ class TeleChat2Model(LlamaModel):
 | 
					@@ -22,10 +22,12 @@
 | 
				
			||||||
 | 
					 from typing import Iterable, Set, Tuple
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 import torch
 | 
				
			||||||
 | 
					+import torch.nn as nn
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 from vllm.config import VllmConfig
 | 
				
			||||||
 | 
					 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 | 
				
			||||||
 | 
					 from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel
 | 
				
			||||||
 | 
					+from .llama import LlamaDecoderLayer
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
 | 
				
			||||||
 | 
					                     is_pp_missing_parameter)
 | 
				
			||||||
 | 
					@@ -44,9 +46,9 @@ class TeleChat2Model(LlamaModel):
 | 
				
			||||||
         for layer in self.layers:
 | 
					         for layer in self.layers:
 | 
				
			||||||
             if not isinstance(layer, PPMissingLayer):
 | 
					             if not isinstance(layer, PPMissingLayer):
 | 
				
			||||||
                 layer.self_attn.qkv_proj.bias = None
 | 
					                 layer.self_attn.qkv_proj.bias = None
 | 
				
			||||||
| 
						 | 
					@ -14254,6 +14421,18 @@ index a38035e37..9631fbd83 100644
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
     def load_weights(self, weights: Iterable[Tuple[str,
 | 
					     def load_weights(self, weights: Iterable[Tuple[str,
 | 
				
			||||||
                                                    torch.Tensor]]) -> Set[str]:
 | 
					                                                    torch.Tensor]]) -> Set[str]:
 | 
				
			||||||
 | 
					@@ -120,7 +122,10 @@ class TeleChat2ForCausalLM(LlamaForCausalLM):
 | 
				
			||||||
 | 
					         },
 | 
				
			||||||
 | 
					     )
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					-    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
 | 
				
			||||||
 | 
					+    def _init_model(self,
 | 
				
			||||||
 | 
					+                    vllm_config: VllmConfig,
 | 
				
			||||||
 | 
					+                    prefix: str = "",
 | 
				
			||||||
 | 
					+                    layer_type: type[nn.Module] = LlamaDecoderLayer):
 | 
				
			||||||
 | 
					         return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					     def load_weights(self, weights: Iterable[Tuple[str,
 | 
				
			||||||
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
 | 
					diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
 | 
				
			||||||
index fc0fb8929..6454e7006 100644
 | 
					index fc0fb8929..6454e7006 100644
 | 
				
			||||||
--- a/vllm/multimodal/utils.py
 | 
					--- a/vllm/multimodal/utils.py
 | 
				
			||||||
| 
						 | 
					@ -14319,7 +14498,7 @@ index b6f6029de..b90fea9fd 100644
 | 
				
			||||||
     def is_neuron(self) -> bool:
 | 
					     def is_neuron(self) -> bool:
 | 
				
			||||||
         return self._enum == PlatformEnum.NEURON
 | 
					         return self._enum == PlatformEnum.NEURON
 | 
				
			||||||
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
 | 
					diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
 | 
				
			||||||
index 225e756cd..4fd7fe220 100644
 | 
					index 225e756cd..25b83549a 100644
 | 
				
			||||||
--- a/vllm/platforms/xpu.py
 | 
					--- a/vllm/platforms/xpu.py
 | 
				
			||||||
+++ b/vllm/platforms/xpu.py
 | 
					+++ b/vllm/platforms/xpu.py
 | 
				
			||||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional
 | 
					@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional
 | 
				
			||||||
| 
						 | 
					@ -14330,7 +14509,17 @@ index 225e756cd..4fd7fe220 100644
 | 
				
			||||||
 from vllm.logger import init_logger
 | 
					 from vllm.logger import init_logger
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 | 
					 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 | 
				
			||||||
@@ -33,8 +34,13 @@ class XPUPlatform(Platform):
 | 
					@@ -25,6 +26,9 @@ class XPUPlatform(Platform):
 | 
				
			||||||
 | 
					     # see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
 | 
				
			||||||
 | 
					     ray_device_key: str = "GPU"
 | 
				
			||||||
 | 
					     device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"
 | 
				
			||||||
 | 
					+    additional_env_vars: list[str] = [
 | 
				
			||||||
 | 
					+        "IPEX_LLM_LOWBIT",
 | 
				
			||||||
 | 
					+    ]
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					     @classmethod
 | 
				
			||||||
 | 
					     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
 | 
				
			||||||
 | 
					@@ -33,8 +37,13 @@ class XPUPlatform(Platform):
 | 
				
			||||||
                              use_mla: bool) -> str:
 | 
					                              use_mla: bool) -> str:
 | 
				
			||||||
         if selected_backend != _Backend.IPEX:
 | 
					         if selected_backend != _Backend.IPEX:
 | 
				
			||||||
             logger.info("Cannot use %s backend on XPU.", selected_backend)
 | 
					             logger.info("Cannot use %s backend on XPU.", selected_backend)
 | 
				
			||||||
| 
						 | 
					@ -14346,7 +14535,7 @@ index 225e756cd..4fd7fe220 100644
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
     @staticmethod
 | 
					     @staticmethod
 | 
				
			||||||
     def get_device_capability(
 | 
					     def get_device_capability(
 | 
				
			||||||
@@ -63,6 +69,8 @@ class XPUPlatform(Platform):
 | 
					@@ -63,6 +72,8 @@ class XPUPlatform(Platform):
 | 
				
			||||||
     @classmethod
 | 
					     @classmethod
 | 
				
			||||||
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 | 
					     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 | 
				
			||||||
         cache_config = vllm_config.cache_config
 | 
					         cache_config = vllm_config.cache_config
 | 
				
			||||||
| 
						 | 
					@ -14355,7 +14544,7 @@ index 225e756cd..4fd7fe220 100644
 | 
				
			||||||
         if cache_config and cache_config.block_size is None:
 | 
					         if cache_config and cache_config.block_size is None:
 | 
				
			||||||
             cache_config.block_size = 16
 | 
					             cache_config.block_size = 16
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
@@ -87,31 +95,46 @@ class XPUPlatform(Platform):
 | 
					@@ -87,31 +98,46 @@ class XPUPlatform(Platform):
 | 
				
			||||||
             raise NotImplementedError(
 | 
					             raise NotImplementedError(
 | 
				
			||||||
                 "XPU does not support speculative decoding")
 | 
					                 "XPU does not support speculative decoding")
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
| 
						 | 
					@ -14412,6 +14601,15 @@ index 225e756cd..4fd7fe220 100644
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
     @classmethod
 | 
					     @classmethod
 | 
				
			||||||
     def is_pin_memory_available(cls):
 | 
					     def is_pin_memory_available(cls):
 | 
				
			||||||
 | 
					@@ -140,3 +166,7 @@ class XPUPlatform(Platform):
 | 
				
			||||||
 | 
					     @classmethod
 | 
				
			||||||
 | 
					     def get_device_communicator_cls(cls) -> str:
 | 
				
			||||||
 | 
					         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    @classmethod
 | 
				
			||||||
 | 
					+    def use_all_gather(cls) -> bool:
 | 
				
			||||||
 | 
					+        return False
 | 
				
			||||||
 | 
					\ No newline at end of file
 | 
				
			||||||
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
 | 
					diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
 | 
				
			||||||
index 53699341b..6bc039068 100644
 | 
					index 53699341b..6bc039068 100644
 | 
				
			||||||
--- a/vllm/transformers_utils/configs/__init__.py
 | 
					--- a/vllm/transformers_utils/configs/__init__.py
 | 
				
			||||||
| 
						 | 
					@ -14432,6 +14630,35 @@ index 53699341b..6bc039068 100644
 | 
				
			||||||
     "ChatGLMConfig",
 | 
					     "ChatGLMConfig",
 | 
				
			||||||
     "Cohere2Config",
 | 
					     "Cohere2Config",
 | 
				
			||||||
     "DbrxConfig",
 | 
					     "DbrxConfig",
 | 
				
			||||||
 | 
					diff --git a/vllm/utils.py b/vllm/utils.py
 | 
				
			||||||
 | 
					index 5f32f8cb6..2ee0c1906 100644
 | 
				
			||||||
 | 
					--- a/vllm/utils.py
 | 
				
			||||||
 | 
					+++ b/vllm/utils.py
 | 
				
			||||||
 | 
					@@ -128,6 +128,8 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
 | 
				
			||||||
 | 
					     "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
 | 
				
			||||||
 | 
					 }
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					+BMG_TARGET_IDS = ["0xe20b", "0xe210"]
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					 # Constants related to forcing the attention backend selection
 | 
				
			||||||
 | 
					 
 | 
				
			||||||
 | 
					 # String name of register which may be set in order to
 | 
				
			||||||
 | 
					@@ -2564,3 +2566,14 @@ def sha256(input) -> int:
 | 
				
			||||||
 | 
					     input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
 | 
				
			||||||
 | 
					     return int.from_bytes(hashlib.sha256(input_bytes).digest(),
 | 
				
			||||||
 | 
					                           byteorder="big")
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+@cache
 | 
				
			||||||
 | 
					+def is_bmg_platform():
 | 
				
			||||||
 | 
					+    if not torch.xpu.is_available():
 | 
				
			||||||
 | 
					+        raise ValueError("Cannot detect the usage of XPU!")
 | 
				
			||||||
 | 
					+    device_index = torch.xpu.current_device()
 | 
				
			||||||
 | 
					+    device_name = torch.xpu.get_device_name(device_index)
 | 
				
			||||||
 | 
					+    for target_id in BMG_TARGET_IDS:
 | 
				
			||||||
 | 
					+        if target_id in device_name:
 | 
				
			||||||
 | 
					+            return True
 | 
				
			||||||
 | 
					+    return False
 | 
				
			||||||
 | 
					\ No newline at end of file
 | 
				
			||||||
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
 | 
					diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
 | 
				
			||||||
index c271f438e..cf7180606 100755
 | 
					index c271f438e..cf7180606 100755
 | 
				
			||||||
--- a/vllm/v1/attention/backends/flash_attn.py
 | 
					--- a/vllm/v1/attention/backends/flash_attn.py
 | 
				
			||||||
| 
						 | 
					@ -14457,10 +14684,10 @@ index c271f438e..cf7180606 100755
 | 
				
			||||||
     assert sliding_window == (-1, -1), (
 | 
					     assert sliding_window == (-1, -1), (
 | 
				
			||||||
diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
 | 
					diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
 | 
				
			||||||
new file mode 100644
 | 
					new file mode 100644
 | 
				
			||||||
index 000000000..29cde02f3
 | 
					index 000000000..f4a435eaa
 | 
				
			||||||
--- /dev/null
 | 
					--- /dev/null
 | 
				
			||||||
+++ b/vllm/v1/attention/backends/ipex_attn.py
 | 
					+++ b/vllm/v1/attention/backends/ipex_attn.py
 | 
				
			||||||
@@ -0,0 +1,358 @@
 | 
					@@ -0,0 +1,392 @@
 | 
				
			||||||
+from dataclasses import dataclass
 | 
					+from dataclasses import dataclass
 | 
				
			||||||
+from typing import Any, Dict, List, Optional, Tuple, Type
 | 
					+from typing import Any, Dict, List, Optional, Tuple, Type
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
| 
						 | 
					@ -14474,6 +14701,7 @@ index 000000000..29cde02f3
 | 
				
			||||||
+from vllm.attention.ops.paged_attn import (PagedAttention,
 | 
					+from vllm.attention.ops.paged_attn import (PagedAttention,
 | 
				
			||||||
+                                           PagedAttentionMetadata)
 | 
					+                                           PagedAttentionMetadata)
 | 
				
			||||||
+from vllm.attention.backends.ipex_attn import use_gqa_kernel
 | 
					+from vllm.attention.backends.ipex_attn import use_gqa_kernel
 | 
				
			||||||
 | 
					+from vllm.utils import is_bmg_platform
 | 
				
			||||||
+import os
 | 
					+import os
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+@dataclass
 | 
					+@dataclass
 | 
				
			||||||
| 
						 | 
					@ -14509,9 +14737,9 @@ index 000000000..29cde02f3
 | 
				
			||||||
+        # if block_size % 16 != 0:
 | 
					+        # if block_size % 16 != 0:
 | 
				
			||||||
+            # raise ValueError("Block size must be a multiple of 16.")
 | 
					+            # raise ValueError("Block size must be a multiple of 16.")
 | 
				
			||||||
+        # This needs to be changed...
 | 
					+        # This needs to be changed...
 | 
				
			||||||
+        # return (2, num_blocks, block_size, num_kv_heads, head_size)
 | 
					+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 | 
				
			||||||
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
 | 
					+        # return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
 | 
				
			||||||
+                                                 num_kv_heads, head_size)
 | 
					+        #                                          num_kv_heads, head_size)
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
| 
						 | 
					@ -14557,6 +14785,8 @@ index 000000000..29cde02f3
 | 
				
			||||||
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 | 
					+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+        support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes()
 | 
					+        support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes()
 | 
				
			||||||
 | 
					+        self.using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap)
 | 
				
			||||||
 | 
					+        self.is_bmg_platform = is_bmg_platform()
 | 
				
			||||||
+        if head_size not in support_head_sizes:
 | 
					+        if head_size not in support_head_sizes:
 | 
				
			||||||
+            raise ValueError(
 | 
					+            raise ValueError(
 | 
				
			||||||
+                f"Head size {head_size} is not supported by FlashAttention. "
 | 
					+                f"Head size {head_size} is not supported by FlashAttention. "
 | 
				
			||||||
| 
						 | 
					@ -14567,7 +14797,6 @@ index 000000000..29cde02f3
 | 
				
			||||||
+                                      "are not implemented for "
 | 
					+                                      "are not implemented for "
 | 
				
			||||||
+                                      "IpexAttnBackendImpl")
 | 
					+                                      "IpexAttnBackendImpl")
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+    # TODO(gc): Refine this logic..., because of bad performance...
 | 
					 | 
				
			||||||
+    def forward(
 | 
					+    def forward(
 | 
				
			||||||
+        self,
 | 
					+        self,
 | 
				
			||||||
+        layer: AttentionLayer,
 | 
					+        layer: AttentionLayer,
 | 
				
			||||||
| 
						 | 
					@ -14610,6 +14839,8 @@ index 000000000..29cde02f3
 | 
				
			||||||
+            k_scale,
 | 
					+            k_scale,
 | 
				
			||||||
+            v_scale,
 | 
					+            v_scale,
 | 
				
			||||||
+            self.scale,
 | 
					+            self.scale,
 | 
				
			||||||
 | 
					+            self.using_gqa_kernel,
 | 
				
			||||||
 | 
					+            self.is_bmg_platform,
 | 
				
			||||||
+            self.sliding_window,
 | 
					+            self.sliding_window,
 | 
				
			||||||
+            self.alibi_slopes,
 | 
					+            self.alibi_slopes,
 | 
				
			||||||
+            self.logits_soft_cap,
 | 
					+            self.logits_soft_cap,
 | 
				
			||||||
| 
						 | 
					@ -14682,6 +14913,8 @@ index 000000000..29cde02f3
 | 
				
			||||||
+    k_scale: float,
 | 
					+    k_scale: float,
 | 
				
			||||||
+    v_scale: float,
 | 
					+    v_scale: float,
 | 
				
			||||||
+    scale: float,
 | 
					+    scale: float,
 | 
				
			||||||
 | 
					+    using_gqa_kernel: bool,
 | 
				
			||||||
 | 
					+    is_bmg_platform: bool,
 | 
				
			||||||
+    sliding_window: Optional[List[int]] = None,
 | 
					+    sliding_window: Optional[List[int]] = None,
 | 
				
			||||||
+    alibi_slopes: Optional[torch.Tensor] = None,
 | 
					+    alibi_slopes: Optional[torch.Tensor] = None,
 | 
				
			||||||
+    logits_soft_cap: Optional[float] = None,
 | 
					+    logits_soft_cap: Optional[float] = None,
 | 
				
			||||||
| 
						 | 
					@ -14700,54 +14933,82 @@ index 000000000..29cde02f3
 | 
				
			||||||
+    key = key.view(-1, num_kv_heads, head_size)
 | 
					+    key = key.view(-1, num_kv_heads, head_size)
 | 
				
			||||||
+    value = value.view(-1, num_kv_heads, head_size)
 | 
					+    value = value.view(-1, num_kv_heads, head_size)
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+    using_gqa_kernel = use_gqa_kernel(num_heads, num_kv_heads, head_size, logits_soft_cap)
 | 
					+    if is_bmg_platform:
 | 
				
			||||||
+
 | 
					+        key_cache, value_cache = kv_cache.unbind(0)
 | 
				
			||||||
+
 | 
					+        ipex_ops.reshape_and_cache_flash(
 | 
				
			||||||
+    if using_gqa_kernel:
 | 
					+            key[:num_actual_tokens],
 | 
				
			||||||
+        key_cache, value_cache = split_kv_cache_ipexllm(
 | 
					+            value[:num_actual_tokens],
 | 
				
			||||||
 | 
					+            key_cache,
 | 
				
			||||||
 | 
					+            value_cache,
 | 
				
			||||||
 | 
					+            attn_metadata.slot_mapping,
 | 
				
			||||||
 | 
					+            kv_cache_dtype,
 | 
				
			||||||
 | 
					+            k_scale,
 | 
				
			||||||
 | 
					+            v_scale,
 | 
				
			||||||
 | 
					+        )
 | 
				
			||||||
 | 
					+        ipex_ops.chunked_prefill(
 | 
				
			||||||
 | 
					+            query[:num_actual_tokens].contiguous(),
 | 
				
			||||||
 | 
					+            key_cache,
 | 
				
			||||||
 | 
					+            value_cache,
 | 
				
			||||||
 | 
					+            output[:num_actual_tokens],
 | 
				
			||||||
 | 
					+            attn_metadata.query_start_loc,
 | 
				
			||||||
 | 
					+            attn_metadata.seq_start_loc,
 | 
				
			||||||
 | 
					+            None,
 | 
				
			||||||
 | 
					+            attn_metadata.block_table,
 | 
				
			||||||
 | 
					+            alibi_slopes,
 | 
				
			||||||
 | 
					+            attn_metadata.max_query_len,
 | 
				
			||||||
 | 
					+            attn_metadata.max_seq_len,
 | 
				
			||||||
 | 
					+            0.0,
 | 
				
			||||||
 | 
					+            scale,
 | 
				
			||||||
 | 
					+            False,
 | 
				
			||||||
 | 
					+            True,
 | 
				
			||||||
 | 
					+            False,
 | 
				
			||||||
 | 
					+            None,
 | 
				
			||||||
 | 
					+        )
 | 
				
			||||||
 | 
					+    else:
 | 
				
			||||||
 | 
					+        if using_gqa_kernel:
 | 
				
			||||||
 | 
					+            key_cache, value_cache = split_kv_cache_ipexllm(
 | 
				
			||||||
 | 
					+                    kv_cache, num_kv_heads, head_size)
 | 
				
			||||||
 | 
					+            ipex_ops.reshape_and_cache_ipexllm(
 | 
				
			||||||
 | 
					+                key[:num_actual_tokens],
 | 
				
			||||||
 | 
					+                value[:num_actual_tokens],
 | 
				
			||||||
 | 
					+                key_cache,
 | 
				
			||||||
 | 
					+                value_cache,
 | 
				
			||||||
 | 
					+                attn_metadata.slot_mapping.flatten(),
 | 
				
			||||||
 | 
					+                kv_cache_dtype,
 | 
				
			||||||
 | 
					+                k_scale,
 | 
				
			||||||
 | 
					+                v_scale,
 | 
				
			||||||
 | 
					+            )
 | 
				
			||||||
 | 
					+        else:
 | 
				
			||||||
 | 
					+            key_cache, value_cache = split_kv_cache(
 | 
				
			||||||
+                kv_cache, num_kv_heads, head_size)
 | 
					+                kv_cache, num_kv_heads, head_size)
 | 
				
			||||||
+        ipex_ops.reshape_and_cache_ipexllm(
 | 
					+            ipex_ops.reshape_and_cache(
 | 
				
			||||||
+            key[:num_actual_tokens],
 | 
					+                key[:num_actual_tokens],
 | 
				
			||||||
+            value[:num_actual_tokens],
 | 
					+                value[:num_actual_tokens],
 | 
				
			||||||
+            key_cache,
 | 
					+                key_cache,
 | 
				
			||||||
+            value_cache,
 | 
					+                value_cache,
 | 
				
			||||||
+            attn_metadata.slot_mapping.flatten(),
 | 
					+                attn_metadata.slot_mapping.flatten(),
 | 
				
			||||||
+            kv_cache_dtype,
 | 
					+                kv_cache_dtype,
 | 
				
			||||||
+            k_scale,
 | 
					+                k_scale,
 | 
				
			||||||
+            v_scale,
 | 
					+                v_scale,
 | 
				
			||||||
+        )
 | 
					+            )
 | 
				
			||||||
+    else:
 | 
					+        # Invoke chunked prefill method...
 | 
				
			||||||
+        key_cache, value_cache = split_kv_cache(
 | 
					+        import vllm._C.ops
 | 
				
			||||||
+            kv_cache, num_kv_heads, head_size)   
 | 
					+        assert head_size == 128 or head_size == 64
 | 
				
			||||||
+        ipex_ops.reshape_and_cache(
 | 
					+        value = os.environ.get('USE_CONTEXT_V1')
 | 
				
			||||||
+            key[:num_actual_tokens],
 | 
					+        query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1]
 | 
				
			||||||
+            value[:num_actual_tokens],
 | 
					+        seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1]
 | 
				
			||||||
+            key_cache,
 | 
					+        context_len = seq_len - query_len
 | 
				
			||||||
+            value_cache,
 | 
					+        if using_gqa_kernel:
 | 
				
			||||||
+            attn_metadata.slot_mapping.flatten(),
 | 
					+            # if using_gqa_kernel, then only the v1 kernel can be used
 | 
				
			||||||
+            kv_cache_dtype,
 | 
					+            out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())
 | 
				
			||||||
+            k_scale,
 | 
					+        elif value is None:
 | 
				
			||||||
+            v_scale,
 | 
					+            # Otherwise, by default use v2 attention forward kernel...
 | 
				
			||||||
+        )
 | 
					+            out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item())
 | 
				
			||||||
+    # Invoke chunked prefill method...
 | 
					+        else:
 | 
				
			||||||
+    import vllm._C.ops
 | 
					+            out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())
 | 
				
			||||||
+    assert head_size == 128 or head_size == 64
 | 
					 | 
				
			||||||
+    value = os.environ.get('USE_CONTEXT_V1')
 | 
					 | 
				
			||||||
+    query_len = attn_metadata.query_start_loc[1:] - attn_metadata.query_start_loc[:-1]
 | 
					 | 
				
			||||||
+    seq_len = attn_metadata.seq_start_loc[1:] - attn_metadata.seq_start_loc[:-1]
 | 
					 | 
				
			||||||
+    context_len = seq_len - query_len
 | 
					 | 
				
			||||||
+    if using_gqa_kernel:
 | 
					 | 
				
			||||||
+        # if using_gqa_kernel, then only the v1 kernel can be used
 | 
					 | 
				
			||||||
+        out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())
 | 
					 | 
				
			||||||
+    elif value is None:
 | 
					 | 
				
			||||||
+        # Otherwise, by default use v2 attention forward kernel...
 | 
					 | 
				
			||||||
+        out = vllm._C.ops.context_attention_forward_v2(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item(), torch.amax(query_len).item())
 | 
					 | 
				
			||||||
+    else:
 | 
					 | 
				
			||||||
+        out = vllm._C.ops.context_attention_forward_v1(query[:num_actual_tokens], key_cache, value_cache, attn_metadata.block_table, attn_metadata.query_start_loc, seq_len, context_len, attn_metadata.max_seq_len, torch.amax(context_len).item())
 | 
					 | 
				
			||||||
+    
 | 
					 | 
				
			||||||
+    # output[:num_actual_tokens] = out
 | 
					 | 
				
			||||||
+    output[:num_actual_tokens] = out.view(out.shape[0], -1)
 | 
					 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
 | 
					+        # output[:num_actual_tokens] = out
 | 
				
			||||||
 | 
					+        output[:num_actual_tokens] = out.view(out.shape[0], -1)
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
| 
						 | 
					@ -15648,10 +15909,10 @@ index 000000000..8612d3d77
 | 
				
			||||||
+            self.kv_caches)
 | 
					+            self.kv_caches)
 | 
				
			||||||
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
 | 
					diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
 | 
				
			||||||
new file mode 100644
 | 
					new file mode 100644
 | 
				
			||||||
index 000000000..1bc531e28
 | 
					index 000000000..1fb0dca87
 | 
				
			||||||
--- /dev/null
 | 
					--- /dev/null
 | 
				
			||||||
+++ b/vllm/v1/worker/xpu_worker.py
 | 
					+++ b/vllm/v1/worker/xpu_worker.py
 | 
				
			||||||
@@ -0,0 +1,168 @@
 | 
					@@ -0,0 +1,175 @@
 | 
				
			||||||
+# SPDX-License-Identifier: Apache-2.0
 | 
					+# SPDX-License-Identifier: Apache-2.0
 | 
				
			||||||
+import os
 | 
					+import os
 | 
				
			||||||
+from typing import Optional
 | 
					+from typing import Optional
 | 
				
			||||||
| 
						 | 
					@ -15685,9 +15946,16 @@ index 000000000..1bc531e28
 | 
				
			||||||
+        assert device_config.device_type == "xpu"
 | 
					+        assert device_config.device_type == "xpu"
 | 
				
			||||||
+        assert current_platform.is_xpu()
 | 
					+        assert current_platform.is_xpu()
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
+    def load_model(self) -> None:
 | 
					+        import os
 | 
				
			||||||
+        self.model_runner.load_model()
 | 
					+        lowbit = os.getenv("IPEX_LLM_LOWBIT", None)
 | 
				
			||||||
 | 
					+        if lowbit is not None:
 | 
				
			||||||
 | 
					+            from ipex_llm.vllm.xpu.model_convert import _ipex_llm_convert
 | 
				
			||||||
 | 
					+            _ipex_llm_convert(lowbit)
 | 
				
			||||||
+
 | 
					+
 | 
				
			||||||
 | 
					+
 | 
				
			||||||
 | 
					+    def compile_or_warm_up_model(self) -> None:
 | 
				
			||||||
 | 
					+        pass
 | 
				
			||||||
 | 
					+        
 | 
				
			||||||
+    # we provide this function due to `torch.xpu.mem_get_info()` doesn't
 | 
					+    # we provide this function due to `torch.xpu.mem_get_info()` doesn't
 | 
				
			||||||
+    # return correct free_gpu_memory on intel client GPU. We need to
 | 
					+    # return correct free_gpu_memory on intel client GPU. We need to
 | 
				
			||||||
+    # calculate/estiamte it.
 | 
					+    # calculate/estiamte it.
 | 
				
			||||||
| 
						 | 
					@ -15838,7 +16106,7 @@ index 86e6d9752..ad80bf54e 100644
 | 
				
			||||||
 
 | 
					 
 | 
				
			||||||
 @dataclass(frozen=True)
 | 
					 @dataclass(frozen=True)
 | 
				
			||||||
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
 | 
					diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
 | 
				
			||||||
index 9d49b4385..67f07f5b1 100644
 | 
					index 9d49b4385..7396b0c89 100644
 | 
				
			||||||
--- a/vllm/worker/xpu_model_runner.py
 | 
					--- a/vllm/worker/xpu_model_runner.py
 | 
				
			||||||
+++ b/vllm/worker/xpu_model_runner.py
 | 
					+++ b/vllm/worker/xpu_model_runner.py
 | 
				
			||||||
@@ -5,8 +5,8 @@ import time
 | 
					@@ -5,8 +5,8 @@ import time
 | 
				
			||||||
| 
						 | 
					@ -16163,7 +16431,7 @@ index 9d49b4385..67f07f5b1 100644
 | 
				
			||||||
+        slot_mapping_tensor = torch.tensor(slot_mapping,
 | 
					+        slot_mapping_tensor = torch.tensor(slot_mapping,
 | 
				
			||||||
+                                           dtype=torch.long,
 | 
					+                                           dtype=torch.long,
 | 
				
			||||||
+                                           device=self.device)
 | 
					+                                           device=self.device)
 | 
				
			||||||
+        if need_block_table:
 | 
					+        if need_block_table or "bge" in self.runner.model_config.model.lower():
 | 
				
			||||||
+            seq_lens_tensor = torch.tensor(seq_lens,
 | 
					+            seq_lens_tensor = torch.tensor(seq_lens,
 | 
				
			||||||
+                                        dtype=torch.int,
 | 
					+                                        dtype=torch.int,
 | 
				
			||||||
+                                        device=self.device)
 | 
					+                                        device=self.device)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
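For context, the telechat2 part of the fix is a signature change: the _init_model override is widened to accept a layer_type keyword, presumably because the updated LlamaForCausalLM base class now calls the hook with that argument. A standalone sketch of the updated hook (the Patched... class name is illustrative; imports follow the module layout shown in the diff; the comment about the base-class hook is an inference, not from the source):

import torch.nn as nn

from vllm.config import VllmConfig
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.model_executor.models.telechat2 import TeleChat2Model


class PatchedTeleChat2ForCausalLM(LlamaForCausalLM):
    def _init_model(self,
                    vllm_config: VllmConfig,
                    prefix: str = "",
                    layer_type: type[nn.Module] = LlamaDecoderLayer):
        # layer_type is accepted only to keep the signature compatible with the
        # base-class hook; TeleChat2Model builds its own decoder layers, so the
        # argument is not forwarded.
        return TeleChat2Model(vllm_config=vllm_config, prefix=prefix)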