update vllm patch (#13211)
Co-authored-by: gc-fu <guancheng.fu@intel.com>
parent ac04992278
commit 5a629ae470
1 changed file with 147 additions and 41 deletions
@@ -10389,7 +10389,7 @@ index bd52fc90b..7d4e3555a 100644
                  if capability < quant_config.get_min_capability():
                      raise ValueError(
 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
-index 89c9b6747..a5be57ce0 100644
+index 89c9b6747..feba4f69f 100644
 --- a/vllm/engine/arg_utils.py
 +++ b/vllm/engine/arg_utils.py
 @@ -210,6 +210,8 @@ class EngineArgs:
@@ -10420,7 +10420,7 @@ index 89c9b6747..a5be57ce0 100644
          parser.add_argument(
              "--disable-cascade-attn",
              action="store_true",
-@@ -1061,6 +1075,8 @@ class EngineArgs:
+@@ -1061,10 +1075,16 @@ class EngineArgs:
              override_generation_config=self.override_generation_config,
              enable_sleep_mode=self.enable_sleep_mode,
              model_impl=self.model_impl,
@@ -10429,7 +10429,26 @@ index 89c9b6747..a5be57ce0 100644
          )
  
      def create_load_config(self) -> LoadConfig:
-@@ -1504,12 +1520,13 @@ class EngineArgs:
+ 
++        use_low_bit_loader = False
++
++        if self.low_bit_model_path is not None:
++            use_low_bit_loader = True
+         if(self.qlora_adapter_name_or_path is not None) and \
+             self.quantization != "bitsandbytes":
+             raise ValueError(
+@@ -1079,8 +1099,10 @@ class EngineArgs:
+             model_loader_extra_config=self.model_loader_extra_config,
+             ignore_patterns=self.ignore_patterns,
+             use_tqdm_on_load=self.use_tqdm_on_load,
++            use_low_bit_loader=use_low_bit_loader,
+         )
+ 
++
+     def create_speculative_config(
+         self,
+         target_model_config: ModelConfig,
+@@ -1504,12 +1526,13 @@ class EngineArgs:
              _raise_or_fallback(feature_name=name, recommend_to_remove=True)
              return False
  
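Note on the hunk above: the commit threads a low-bit loading switch through EngineArgs.create_load_config(), so a LoadConfig built from the engine args carries use_low_bit_loader=True whenever low_bit_model_path is set. The snippet below is only an illustrative sketch of how that wiring would be exercised; it assumes the patched IPEX-LLM build, where EngineArgs actually exposes a low_bit_model_path field (added elsewhere in this patch) and LoadConfig accepts use_low_bit_loader. The model id and checkpoint path are hypothetical.

    # Sketch, not part of the patch: exercising the new low-bit loader wiring.
    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="Qwen/Qwen2.5-7B-Instruct")        # any supported model id
    args.low_bit_model_path = "/models/qwen2.5-7b-sym-int4"    # hypothetical converted checkpoint
    load_config = args.create_load_config()
    # Per the hunk above, load_config.use_low_bit_loader is now True;
    # leaving low_bit_model_path as None keeps the stock vLLM loader.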
@@ -12669,6 +12688,23 @@ index c190a4585..dda2a96cc 100644
          boi = self.boi.expand(x.shape[0], -1, -1)
          eoi = self.eoi.expand(x.shape[0], -1, -1)
          x = torch.cat((boi, x, eoi), dim=1)
+diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
+index cb0379c10..5e8b22ab0 100644
+--- a/vllm/model_executor/models/idefics2_vision_model.py
++++ b/vllm/model_executor/models/idefics2_vision_model.py
+@@ -144,8 +144,10 @@ class Idefics2VisionAttention(nn.Module):
+         )
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
+-        self.attn = MultiHeadAttention(self.num_heads_per_partition,
+-                                       self.head_dim, self.scale)
++        # self.attn = MultiHeadAttention(self.num_heads_per_partition,
++        #                                self.head_dim, self.scale)
++        from vllm.model_executor.models.siglip import SelfAttention
++        self.attn = SelfAttention(self.num_heads_per_partition, self.head_dim, self.scale)
+ 
+     def forward(
+         self,
 diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
 index 5fab9df3f..f8e6fbe24 100644
 --- a/vllm/model_executor/models/minicpmv.py
@@ -13552,6 +13588,18 @@ index 000000000..d96085f46
 +            hidden_states=encoder_outputs.hidden_states,
 +            attentions=encoder_outputs.attentions,
 +        )
+diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py
+index db90848f9..5eabcf653 100644
+--- a/vllm/model_executor/models/phi4mm_audio.py
++++ b/vllm/model_executor/models/phi4mm_audio.py
+@@ -230,6 +230,7 @@ class ConformerEncoderLayer(nn.Module):
+         x = x + 0.5 * self.feed_forward_in(x)
+         norm_x = self.layer_norm_att(x)
+ 
++        mask = mask.to(x.device)
+         x = x + self.self_attn(
+             norm_x,
+             norm_x,
 diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
 index c4d02e5dd..2831a5a12 100644
 --- a/vllm/model_executor/models/qwen2.py
@@ -13589,41 +13637,85 @@ index c4d02e5dd..2831a5a12 100644
          )
  
 diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
-index 1e6ff1fec..e2480326a 100644
+index 1e6ff1fec..90ebe5ca9 100644
 --- a/vllm/model_executor/models/qwen2_5_vl.py
 +++ b/vllm/model_executor/models/qwen2_5_vl.py
-@@ -304,6 +304,10 @@ class Qwen2_5_VisionAttention(nn.Module):
+@@ -302,23 +302,33 @@ class Qwen2_5_VisionAttention(nn.Module):
+                                       "(b s) ... -> b s ...",
+                                       b=batch_size)
          elif self.attn_backend == _Backend.TORCH_SDPA:
-             # Execute attention entry by entry for speed & less VRAM.
-             outputs = []
-+            head_dim = q.shape[-1]
-+            import math
-+            import xe_addons
-+            scale = 1 / math.sqrt(head_dim)
-             for i in range(1, len(cu_seqlens)):
-                 start_idx = cu_seqlens[i - 1]
-                 end_idx = cu_seqlens[i]
-@@ -312,10 +316,16 @@ class Qwen2_5_VisionAttention(nn.Module):
-                 v_i = v[:, start_idx:end_idx]
-                 q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
-                                  for x in [q_i, k_i, v_i])
+-            # Execute attention entry by entry for speed & less VRAM.
+-            outputs = []
+-            for i in range(1, len(cu_seqlens)):
+-                start_idx = cu_seqlens[i - 1]
+-                end_idx = cu_seqlens[i]
+-                q_i = q[:, start_idx:end_idx]
+-                k_i = k[:, start_idx:end_idx]
+-                v_i = v[:, start_idx:end_idx]
+-                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
+-                                 for x in [q_i, k_i, v_i])
 -                output_i = F.scaled_dot_product_attention(q_i,
 -                                                          k_i,
 -                                                          v_i,
 -                                                          dropout_p=0.0)
-+                # output_i = F.scaled_dot_product_attention(q_i,
-+                #                                           k_i,
-+                #                                           v_i,
-+                #                                           dropout_p=0.0)
-+                output_i = xe_addons.sdp_non_causal(
-+                    q_i.contiguous(),
-+                    k_i.contiguous(),
-+                    v_i.contiguous(),
-+                    None,
-+                    scale)
-                 output_i = rearrange(output_i, "b h s d -> b s h d ")
-                 outputs.append(output_i)
-             context_layer = torch.cat(outputs, dim=1)
+-                output_i = rearrange(output_i, "b h s d -> b s h d ")
+-                outputs.append(output_i)
+-            context_layer = torch.cat(outputs, dim=1)
++            # TODO(xiangyu): Maybe add attn_backend xpu?
++            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
++            from vllm._ipex_ops import ipex_ops
++            output = torch.empty(
++                        (q.shape[0], q.shape[1], q.shape[2]),
++                        dtype=q.dtype,
++                        device=q.device)
++            import math
++            head_dim = q.shape[-1]
++            scale = 1 / math.sqrt(head_dim)
++            ipex_ops.varlen_attention(q, k, v, output,
++                                    cu_seqlens,
++                                    cu_seqlens,
++                                    max_seqlen,
++                                    max_seqlen,
++                                    pdropout=0,
++                                    softmax_scale=scale,
++                                    zero_tensors=False,
++                                    is_causal=False,
++                                    return_softmax=False,
++                                    gen_=None,
++                                    logits_soft_cap=0
++                                    )
++
++            context_layer = rearrange(output,
++                                      "(b s) ... -> b s ...",
++                                      b=batch_size)
+         elif self.attn_backend == _Backend.XFORMERS:
+             from xformers import ops as xops
+             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+@@ -613,10 +623,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
+         cu_seqlens: torch.Tensor,
+     ) -> tuple[Optional[int], Optional[list[int]]]:
+         max_seqlen, seqlens = None, None
+-        if self.attn_backend == _Backend.FLASH_ATTN:
+-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+-        elif self.attn_backend == _Backend.XFORMERS:
+-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
++        # if self.attn_backend == _Backend.FLASH_ATTN:
++        #     max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
++        # elif self.attn_backend == _Backend.XFORMERS:
++        #     seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
++        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+         return max_seqlen, seqlens
+ 
+     def forward(
+@@ -1082,7 +1093,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
+                     image_input=image_input,
+                     video_input=video_input)
+                 input_ids = None
+-
++ 
+         hidden_states = self.language_model.model(
+             input_ids=input_ids,
+             positions=positions,
 diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
 index a7800d415..26af87512 100644
 --- a/vllm/model_executor/models/qwen2_vl.py
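For readers tracking the TORCH_SDPA branch rewrite above: instead of looping over cu_seqlens and calling F.scaled_dot_product_attention per segment, the updated patch packs q/k/v to (total_tokens, heads, head_dim) and hands the segment boundaries to ipex_ops.varlen_attention. Below is a minimal pure-PyTorch sketch of that contract (my own reference, assuming only torch); it mirrors the per-segment loop the patch deletes and relies on SDPA's default 1/sqrt(head_dim) scale, which matches the softmax_scale computed above.

    import torch
    import torch.nn.functional as F

    def varlen_attention_reference(q, k, v, cu_seqlens):
        # q, k, v: (total_tokens, num_heads, head_dim); cu_seqlens: (num_seqs + 1,) prefix sums
        out = torch.empty_like(q)
        for i in range(1, len(cu_seqlens)):
            s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
            # reshape each segment to (1, num_heads, seq_len, head_dim) for SDPA
            q_i, k_i, v_i = (x[s:e].transpose(0, 1).unsqueeze(0) for x in (q, k, v))
            o = F.scaled_dot_product_attention(q_i, k_i, v_i)  # non-causal, dropout 0
            out[s:e] = o.squeeze(0).transpose(0, 1)
        return out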
@@ -15133,10 +15225,10 @@ index c271f438e..cf7180606 100755
      assert sliding_window == (-1, -1), (
 diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
 new file mode 100644
-index 000000000..f4a435eaa
+index 000000000..964696cfe
 --- /dev/null
 +++ b/vllm/v1/attention/backends/ipex_attn.py
-@@ -0,0 +1,392 @@
+@@ -0,0 +1,404 @@
 +from dataclasses import dataclass
 +from typing import Any, Dict, List, Optional, Tuple, Type
 +
@@ -15152,6 +15244,10 @@ index 000000000..f4a435eaa
 +from vllm.attention.backends.ipex_attn import use_gqa_kernel
 +from vllm.utils import is_bmg_platform
 +import os
++from vllm.logger import init_logger
++
++logger = init_logger(__name__)
++
 +
 +@dataclass
 +class IPEXAttentionMetadata(FlashAttentionMetadata):
@@ -15246,6 +15342,12 @@ index 000000000..f4a435eaa
 +                                      "are not implemented for "
 +                                      "IpexAttnBackendImpl")
 +
++        flag = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None)
++        self.ipex_varlen_attn = False
++        if flag is not None:
++            self.ipex_varlen_attn = True
++            logger.info_once(f"V1 engine using varlen_attention for prefilling.")
++
 +    def forward(
 +        self,
 +        layer: AttentionLayer,
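The new V1 ipex_attn backend above gates its prefill path on an environment variable rather than a config field: os.getenv returns None only when the variable is unset, so exporting IPEX_LLM_PREFILL_VARLEN_BACKEND with any value, even "0" or an empty string, switches prefill to varlen_attention. A small sketch of the same check:

    import os

    # Any value, including "0" or "", enables the varlen prefill path; only an unset variable disables it.
    use_varlen_prefill = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None) is not None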
@@ -15293,6 +15395,7 @@ index 000000000..f4a435eaa
 +            self.sliding_window,
 +            self.alibi_slopes,
 +            self.logits_soft_cap,
++            self.ipex_varlen_attn,
 +        )
 +        return output.view(-1, self.num_heads * self.head_size)
 +
@@ -15367,6 +15470,7 @@ index 000000000..f4a435eaa
 +    sliding_window: Optional[List[int]] = None,
 +    alibi_slopes: Optional[torch.Tensor] = None,
 +    logits_soft_cap: Optional[float] = None,
++    flag: Optional[bool] = False,
 +) -> None:
 +    context = get_forward_context()
 +    current_metadata = context.attn_metadata
@@ -15382,7 +15486,7 @@ index 000000000..f4a435eaa
 +    key = key.view(-1, num_kv_heads, head_size)
 +    value = value.view(-1, num_kv_heads, head_size)
 +
-+    if is_bmg_platform:
++    if flag or is_bmg_platform:
 +        key_cache, value_cache = kv_cache.unbind(0)
 +        ipex_ops.reshape_and_cache_flash(
 +            key[:num_actual_tokens],
@@ -17087,7 +17191,7 @@ index 000000000..dffc7b367
 +        return (attn_metadata, encoder_input_tokens_tensor,
 +                encoder_input_positions_tensor)
 diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
-index 9d49b4385..78e0c54f2 100644
+index 9d49b4385..dc5e95f4e 100644
 --- a/vllm/worker/xpu_model_runner.py
 +++ b/vllm/worker/xpu_model_runner.py
 @@ -5,8 +5,8 @@ import time
@@ -17735,15 +17839,17 @@ index 9d49b4385..78e0c54f2 100644
          max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
              self.model_config)
          if max_mm_tokens > 0:
-@@ -461,6 +820,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -461,6 +820,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                      "Computed max_num_seqs (%s) to be less than 1. "
                      "Setting it to the minimum value of 1.", expr)
                  max_num_seqs = 1
 +        '''
++        if "phi4mm" in self.model_config.hf_config.model_type:
++            max_num_seqs = 1
  
          batch_size = 0
          for group_id in range(max_num_seqs):
-@@ -479,11 +839,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -479,11 +841,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                  seq_data={group_id: dummy_data.seq_data},
                  sampling_params=sampling_params,
                  block_tables=None,
@@ -17759,7 +17865,7 @@ index 9d49b4385..78e0c54f2 100644
          finished_requests_ids = [seq.request_id for seq in seqs]
          model_input = self.prepare_model_input(
              seqs, finished_requests_ids=finished_requests_ids)
-@@ -493,25 +856,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -493,25 +858,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                  batch_size=batch_size,
                  dtype=self.model_config.dtype,
                  device=self.device)
@@ -17810,7 +17916,7 @@ index 9d49b4385..78e0c54f2 100644
          """Helper method to prepare the model input based on a given sequence
          group. Prepares metadata needed for the base model forward pass but not
          metadata for possible additional steps, e.g., sampling.
-@@ -524,6 +901,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -524,6 +903,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
  
          return builder.build()  # type: ignore
  
@@ -17833,7 +17939,7 @@ index 9d49b4385..78e0c54f2 100644
      def prepare_model_input(
          self,
          seq_group_metadata_list: List[SequenceGroupMetadata],
-@@ -563,6 +956,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -563,6 +958,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
              raise ValueError(
                  "XPUModelRunner does not support multi-step execution.")
  
@@ -17846,7 +17952,7 @@ index 9d49b4385..78e0c54f2 100644
          model_executable = self.model
          if (self.observability_config is not None
                  and self.observability_config.collect_model_forward_time):
-@@ -612,3 +1011,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -612,3 +1013,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
              output.model_forward_time = model_forward_time
  
          return [output]