update vllm patch (#13211)

Co-authored-by: gc-fu <guancheng.fu@intel.com>
Shaojun Liu 2025-06-06 17:20:45 +08:00 committed by GitHub
parent ac04992278
commit 5a629ae470

@@ -10389,7 +10389,7 @@ index bd52fc90b..7d4e3555a 100644
if capability < quant_config.get_min_capability():
raise ValueError(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
-index 89c9b6747..a5be57ce0 100644
+index 89c9b6747..feba4f69f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -210,6 +210,8 @@ class EngineArgs:
@@ -10420,7 +10420,7 @@ index 89c9b6747..a5be57ce0 100644
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
-@@ -1061,6 +1075,8 @@ class EngineArgs:
+@@ -1061,10 +1075,16 @@ class EngineArgs:
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
@@ -10429,7 +10429,26 @@ index 89c9b6747..a5be57ce0 100644
)
def create_load_config(self) -> LoadConfig:
@@ -1504,12 +1520,13 @@ class EngineArgs:
+ use_low_bit_loader = False
+
+ if self.low_bit_model_path is not None:
+ use_low_bit_loader = True
if(self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
@@ -1079,8 +1099,10 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
+ use_low_bit_loader=use_low_bit_loader,
)
+
def create_speculative_config(
self,
target_model_config: ModelConfig,
@@ -1504,12 +1526,13 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False
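For context on the arg_utils.py change above: the patch derives a use_low_bit_loader flag from a new low_bit_model_path setting and forwards it into LoadConfig. A minimal, hedged sketch of that control flow, with a simplified stand-in LoadConfig (the real vllm.config.LoadConfig has many more fields) and an illustrative model path:

from dataclasses import dataclass
from typing import Optional

@dataclass
class LoadConfig:
    # simplified stand-in; the real vLLM dataclass carries many more fields
    use_low_bit_loader: bool = False

def create_load_config(low_bit_model_path: Optional[str]) -> LoadConfig:
    # the flag is only switched on when a low-bit checkpoint path is supplied
    use_low_bit_loader = low_bit_model_path is not None
    return LoadConfig(use_low_bit_loader=use_low_bit_loader)

print(create_load_config("/models/qwen2-int4"))  # LoadConfig(use_low_bit_loader=True), path is illustrative
print(create_load_config(None))                  # LoadConfig(use_low_bit_loader=False)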
@@ -12669,6 +12688,23 @@ index c190a4585..dda2a96cc 100644
boi = self.boi.expand(x.shape[0], -1, -1)
eoi = self.eoi.expand(x.shape[0], -1, -1)
x = torch.cat((boi, x, eoi), dim=1)
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index cb0379c10..5e8b22ab0 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -144,8 +144,10 @@ class Idefics2VisionAttention(nn.Module):
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
- self.attn = MultiHeadAttention(self.num_heads_per_partition,
- self.head_dim, self.scale)
+ # self.attn = MultiHeadAttention(self.num_heads_per_partition,
+ # self.head_dim, self.scale)
+ from vllm.model_executor.models.siglip import SelfAttention
+ self.attn = SelfAttention(self.num_heads_per_partition, self.head_dim, self.scale)
def forward(
self,
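The idefics2_vision_model.py hunk above swaps vLLM's MultiHeadAttention for the SelfAttention module defined in vllm/model_executor/models/siglip.py. As a rough, hedged sketch only (the real siglip SelfAttention may have a different interface), a module constructed with (num_heads, head_dim, scale) and computing plain scaled-dot-product self-attention could look like:

import torch
import torch.nn.functional as F

class NaiveSelfAttention(torch.nn.Module):
    # Illustrative stand-in for the SelfAttention referenced above, not the real class.
    def __init__(self, num_heads: int, head_dim: int, scale: float):
        super().__init__()
        self.num_heads, self.head_dim, self.scale = num_heads, head_dim, scale

    def forward(self, q, k, v):
        # q, k, v: (batch, seq, num_heads * head_dim)
        b, s, _ = q.shape
        shape = (b, s, self.num_heads, self.head_dim)
        q, k, v = (x.view(shape).transpose(1, 2) for x in (q, k, v))  # (b, h, s, d)
        out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        return out.transpose(1, 2).reshape(b, s, -1)

attn = NaiveSelfAttention(num_heads=4, head_dim=16, scale=16 ** -0.5)
x = torch.randn(2, 10, 64)
print(attn(x, x, x).shape)  # torch.Size([2, 10, 64])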
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 5fab9df3f..f8e6fbe24 100644
--- a/vllm/model_executor/models/minicpmv.py
@@ -13552,6 +13588,18 @@ index 000000000..d96085f46
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py
index db90848f9..5eabcf653 100644
--- a/vllm/model_executor/models/phi4mm_audio.py
+++ b/vllm/model_executor/models/phi4mm_audio.py
@@ -230,6 +230,7 @@ class ConformerEncoderLayer(nn.Module):
x = x + 0.5 * self.feed_forward_in(x)
norm_x = self.layer_norm_att(x)
+ mask = mask.to(x.device)
x = x + self.self_attn(
norm_x,
norm_x,
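The phi4mm_audio.py change above simply moves the attention mask onto the same device as the activations before self-attention (e.g. a CPU-built mask meeting XPU tensors). A minimal sketch of the failure mode and the fix, using generic torch tensors:

import torch

def add_masked_scores(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # x may live on an accelerator while mask was built on the CPU;
    # torch raises a device-mismatch error unless the mask is moved first.
    mask = mask.to(x.device)
    return x.masked_fill(~mask, float("-inf"))

x = torch.randn(2, 4)                       # stands in for an xpu/cuda tensor
mask = torch.ones(2, 4, dtype=torch.bool)   # built on CPU
print(add_masked_scores(x, mask).shape)     # torch.Size([2, 4])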
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index c4d02e5dd..2831a5a12 100644
--- a/vllm/model_executor/models/qwen2.py
@@ -13589,41 +13637,85 @@ index c4d02e5dd..2831a5a12 100644
)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
-index 1e6ff1fec..e2480326a 100644
+index 1e6ff1fec..90ebe5ca9 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py

Old patch hunk (before this commit):
@@ -304,6 +304,10 @@ class Qwen2_5_VisionAttention(nn.Module):
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []
+           head_dim = q.shape[-1]
+           import math
+           import xe_addons
+           scale = 1 / math.sqrt(head_dim)
            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
@@ -312,10 +316,16 @@ class Qwen2_5_VisionAttention(nn.Module):
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
                                 for x in [q_i, k_i, v_i])
-               output_i = F.scaled_dot_product_attention(q_i,
-                                                         k_i,
-                                                         v_i,
-                                                         dropout_p=0.0)
+               # output_i = F.scaled_dot_product_attention(q_i,
+               #                                           k_i,
+               #                                           v_i,
+               #                                           dropout_p=0.0)
+               output_i = xe_addons.sdp_non_causal(
+                   q_i.contiguous(),
+                   k_i.contiguous(),
+                   v_i.contiguous(),
+                   None,
+                   scale)
                output_i = rearrange(output_i, "b h s d -> b s h d ")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)

Updated patch hunk (after this commit):
@@ -302,23 +302,33 @@ class Qwen2_5_VisionAttention(nn.Module):
                             "(b s) ... -> b s ...",
                             b=batch_size)
        elif self.attn_backend == _Backend.TORCH_SDPA:
-           # Execute attention entry by entry for speed & less VRAM.
-           outputs = []
-           for i in range(1, len(cu_seqlens)):
-               start_idx = cu_seqlens[i - 1]
-               end_idx = cu_seqlens[i]
-               q_i = q[:, start_idx:end_idx]
-               k_i = k[:, start_idx:end_idx]
-               v_i = v[:, start_idx:end_idx]
-               q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
-                                for x in [q_i, k_i, v_i])
-               output_i = F.scaled_dot_product_attention(q_i,
-                                                         k_i,
-                                                         v_i,
-                                                         dropout_p=0.0)
-               output_i = rearrange(output_i, "b h s d -> b s h d ")
-               outputs.append(output_i)
-           context_layer = torch.cat(outputs, dim=1)
+           # TODO(xiangyu): Maybe add attn_backend xpu?
+           q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+           from vllm._ipex_ops import ipex_ops
+           output = torch.empty(
+               (q.shape[0], q.shape[1], q.shape[2]),
+               dtype=q.dtype,
+               device=q.device)
+           import math
+           head_dim = q.shape[-1]
+           scale = 1 / math.sqrt(head_dim)
+           ipex_ops.varlen_attention(q, k, v, output,
+                                     cu_seqlens,
+                                     cu_seqlens,
+                                     max_seqlen,
+                                     max_seqlen,
+                                     pdropout=0,
+                                     softmax_scale=scale,
+                                     zero_tensors=False,
+                                     is_causal=False,
+                                     return_softmax=False,
+                                     gen_=None,
+                                     logits_soft_cap=0
+                                     )
+
+           context_layer = rearrange(output,
+                                     "(b s) ... -> b s ...",
+                                     b=batch_size)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
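The rewritten TORCH_SDPA branch above packs all sequences into one (batch*seq) token dimension and hands cumulative sequence lengths to ipex_ops.varlen_attention instead of looping per sequence. A rough, hedged reference of the same idea in plain PyTorch (standing in for the IPEX XPU kernel), useful only to illustrate the packing:

import math
import torch
import torch.nn.functional as F

def varlen_sdpa_reference(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, num_heads, head_dim), packed across sequences;
    # cu_seqlens: cumulative lengths, e.g. tensor([0, 3, 8]) for lengths 3 and 5.
    scale = 1 / math.sqrt(q.shape[-1])
    out = torch.empty_like(q)
    for i in range(1, len(cu_seqlens)):
        s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        q_i, k_i, v_i = (x[s:e].transpose(0, 1) for x in (q, k, v))  # (heads, len, dim)
        o = F.scaled_dot_product_attention(q_i, k_i, v_i, scale=scale)
        out[s:e] = o.transpose(0, 1)
    return out

q = k = v = torch.randn(8, 2, 16)  # two sequences of lengths 3 and 5, 2 heads
print(varlen_sdpa_reference(q, k, v, torch.tensor([0, 3, 8])).shape)  # (8, 2, 16)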
@@ -613,10 +623,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
- if self.attn_backend == _Backend.FLASH_ATTN:
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == _Backend.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ # if self.attn_backend == _Backend.FLASH_ATTN:
+ # max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ # elif self.attn_backend == _Backend.XFORMERS:
+ # seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens
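The hunk above makes max_seqlen always come from cu_seqlens rather than only for the FLASH_ATTN/XFORMERS backends. With cumulative lengths, per-sequence lengths are just adjacent differences, for example:

import torch

cu_seqlens = torch.tensor([0, 3, 8, 12])
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()          # [3, 5, 4]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()   # 5
print(seqlens, max_seqlen)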
def forward(
@@ -1082,7 +1093,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input=image_input,
video_input=video_input)
input_ids = None
-
+
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index a7800d415..26af87512 100644
--- a/vllm/model_executor/models/qwen2_vl.py
@@ -15133,10 +15225,10 @@ index c271f438e..cf7180606 100755
assert sliding_window == (-1, -1), (
diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
new file mode 100644
-index 000000000..f4a435eaa
+index 000000000..964696cfe
--- /dev/null
+++ b/vllm/v1/attention/backends/ipex_attn.py
-@@ -0,0 +1,392 @@
+@@ -0,0 +1,404 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
@@ -15152,6 +15244,10 @@ index 000000000..f4a435eaa
+from vllm.attention.backends.ipex_attn import use_gqa_kernel
+from vllm.utils import is_bmg_platform
+import os
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class IPEXAttentionMetadata(FlashAttentionMetadata):
@@ -15246,6 +15342,12 @@ index 000000000..f4a435eaa
+ "are not implemented for "
+ "IpexAttnBackendImpl")
+
+ flag = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None)
+ self.ipex_varlen_attn = False
+ if flag is not None:
+ self.ipex_varlen_attn = True
+ logger.info_once(f"V1 engine using varlen_attention for prefilling.")
+
+ def forward(
+ self,
+ layer: AttentionLayer,
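The new V1 IPEX attention backend added above reads the IPEX_LLM_PREFILL_VARLEN_BACKEND environment variable to decide whether prefill goes through varlen attention. A hedged illustration of how that opt-in is driven (the patch only checks that the variable is set, so the value itself does not matter):

import os

# Any value enables the varlen prefill path in the patched backend above.
os.environ["IPEX_LLM_PREFILL_VARLEN_BACKEND"] = "1"

flag = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None)
ipex_varlen_attn = flag is not None
print(ipex_varlen_attn)  # True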
@@ -15293,6 +15395,7 @@ index 000000000..f4a435eaa
+ self.sliding_window,
+ self.alibi_slopes,
+ self.logits_soft_cap,
+ self.ipex_varlen_attn,
+ )
+ return output.view(-1, self.num_heads * self.head_size)
+
@@ -15367,6 +15470,7 @@ index 000000000..f4a435eaa
+ sliding_window: Optional[List[int]] = None,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ logits_soft_cap: Optional[float] = None,
+ flag: Optional[bool] = False,
+) -> None:
+ context = get_forward_context()
+ current_metadata = context.attn_metadata
@@ -15382,7 +15486,7 @@ index 000000000..f4a435eaa
+ key = key.view(-1, num_kv_heads, head_size)
+ value = value.view(-1, num_kv_heads, head_size)
+
-+ if is_bmg_platform:
++ if flag or is_bmg_platform:
+ key_cache, value_cache = kv_cache.unbind(0)
+ ipex_ops.reshape_and_cache_flash(
+ key[:num_actual_tokens],
@@ -17087,7 +17191,7 @@ index 000000000..dffc7b367
+ return (attn_metadata, encoder_input_tokens_tensor,
+ encoder_input_positions_tensor)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
-index 9d49b4385..78e0c54f2 100644
+index 9d49b4385..dc5e95f4e 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -5,8 +5,8 @@ import time
@@ -17735,15 +17839,17 @@ index 9d49b4385..78e0c54f2 100644
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
-@@ -461,6 +820,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -461,6 +820,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
+ '''
+ if "phi4mm" in self.model_config.hf_config.model_type:
+ max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
-@@ -479,11 +839,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -479,11 +841,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
@@ -17759,7 +17865,7 @@ index 9d49b4385..78e0c54f2 100644
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
-@@ -493,25 +856,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -493,25 +858,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
@@ -17810,7 +17916,7 @@ index 9d49b4385..78e0c54f2 100644
"""Helper method to prepare the model input based on a given sequence """Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling. metadata for possible additional steps, e.g., sampling.
@@ -524,6 +901,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): @@ -524,6 +903,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
return builder.build() # type: ignore
@@ -17833,7 +17939,7 @@ index 9d49b4385..78e0c54f2 100644
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
-@@ -563,6 +956,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -563,6 +958,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
raise ValueError(
"XPUModelRunner does not support multi-step execution.")
@@ -17846,7 +17952,7 @@ index 9d49b4385..78e0c54f2 100644
model_executable = self.model
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
-@@ -612,3 +1011,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
+@@ -612,3 +1013,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
output.model_forward_time = model_forward_time
return [output]