update vllm patch (#13211)

Co-authored-by: gc-fu <guancheng.fu@intel.com>
Shaojun Liu 2025-06-06 17:20:45 +08:00 committed by GitHub
parent ac04992278
commit 5a629ae470


@@ -10389,7 +10389,7 @@ index bd52fc90b..7d4e3555a 100644
if capability < quant_config.get_min_capability():
raise ValueError(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 89c9b6747..a5be57ce0 100644
index 89c9b6747..feba4f69f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -210,6 +210,8 @@ class EngineArgs:
@@ -10420,7 +10420,7 @@ index 89c9b6747..a5be57ce0 100644
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
@@ -1061,6 +1075,8 @@ class EngineArgs:
@@ -1061,10 +1075,16 @@ class EngineArgs:
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
@@ -10429,7 +10429,26 @@ index 89c9b6747..a5be57ce0 100644
)
def create_load_config(self) -> LoadConfig:
@@ -1504,12 +1520,13 @@ class EngineArgs:
+ use_low_bit_loader = False
+
+ if self.low_bit_model_path is not None:
+ use_low_bit_loader = True
if(self.qlora_adapter_name_or_path is not None) and \
self.quantization != "bitsandbytes":
raise ValueError(
@@ -1079,8 +1099,10 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
+ use_low_bit_loader=use_low_bit_loader,
)
+
def create_speculative_config(
self,
target_model_config: ModelConfig,
@@ -1504,12 +1526,13 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False
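
Note on the EngineArgs hunks above: the new low-bit support only derives a boolean use_low_bit_loader from low_bit_model_path and forwards it into the load config. A minimal sketch of the same derivation pattern, assuming a simplified stand-in for EngineArgs (LoaderArgs and create_load_flags are illustrative names, not part of the patch):

from dataclasses import dataclass
from typing import Optional

@dataclass
class LoaderArgs:
    # Hypothetical stand-in for EngineArgs; only the low-bit field is modeled.
    low_bit_model_path: Optional[str] = None

    def create_load_flags(self) -> dict:
        # Mirrors the patched logic: the flag is derived from the path, never set directly.
        use_low_bit_loader = self.low_bit_model_path is not None
        return {"use_low_bit_loader": use_low_bit_loader}

assert LoaderArgs("/models/qwen2-low-bit").create_load_flags()["use_low_bit_loader"]
assert not LoaderArgs().create_load_flags()["use_low_bit_loader"]
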
@@ -12669,6 +12688,23 @@ index c190a4585..dda2a96cc 100644
boi = self.boi.expand(x.shape[0], -1, -1)
eoi = self.eoi.expand(x.shape[0], -1, -1)
x = torch.cat((boi, x, eoi), dim=1)
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index cb0379c10..5e8b22ab0 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -144,8 +144,10 @@ class Idefics2VisionAttention(nn.Module):
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
- self.attn = MultiHeadAttention(self.num_heads_per_partition,
- self.head_dim, self.scale)
+ # self.attn = MultiHeadAttention(self.num_heads_per_partition,
+ # self.head_dim, self.scale)
+ from vllm.model_executor.models.siglip import SelfAttention
+ self.attn = SelfAttention(self.num_heads_per_partition, self.head_dim, self.scale)
def forward(
self,
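
In the Idefics2VisionAttention hunk above, the stock MultiHeadAttention is swapped for SelfAttention imported from vllm.model_executor.models.siglip. The siglip implementation itself is not shown in this diff; the following is only an illustrative sketch, assuming batch-first [batch, seq, heads * head_dim] inputs, of what a per-partition self-attention module of that shape computes:

import torch
import torch.nn.functional as F

class NaiveSelfAttention(torch.nn.Module):
    # Illustrative stand-in only; not the actual vllm siglip SelfAttention.
    def __init__(self, num_heads: int, head_dim: int, scale: float):
        super().__init__()
        self.num_heads, self.head_dim, self.scale = num_heads, head_dim, scale

    def forward(self, q, k, v):
        # [batch, seq, heads * dim] -> [batch, heads, seq, dim]
        def split(x):
            b, s, _ = x.shape
            return x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(split(q), split(k), split(v), scale=self.scale)
        b, h, s, d = out.shape
        return out.transpose(1, 2).reshape(b, s, h * d)
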
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 5fab9df3f..f8e6fbe24 100644
--- a/vllm/model_executor/models/minicpmv.py
@@ -13552,6 +13588,18 @@ index 000000000..d96085f46
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
diff --git a/vllm/model_executor/models/phi4mm_audio.py b/vllm/model_executor/models/phi4mm_audio.py
index db90848f9..5eabcf653 100644
--- a/vllm/model_executor/models/phi4mm_audio.py
+++ b/vllm/model_executor/models/phi4mm_audio.py
@@ -230,6 +230,7 @@ class ConformerEncoderLayer(nn.Module):
x = x + 0.5 * self.feed_forward_in(x)
norm_x = self.layer_norm_att(x)
+ mask = mask.to(x.device)
x = x + self.self_attn(
norm_x,
norm_x,
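
The one-line phi4mm_audio change above moves the attention mask onto the device of the hidden states before self-attention, guarding against a CPU mask meeting XPU activations. A minimal reproduction of the same guard (mask semantics here are illustrative: a boolean keep-mask):

import torch

def apply_mask_safely(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same pattern as the patch: align devices first, otherwise PyTorch raises a
    # RuntimeError about tensors being on different devices.
    mask = mask.to(scores.device)
    return scores.masked_fill(~mask, float("-inf"))
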
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index c4d02e5dd..2831a5a12 100644
--- a/vllm/model_executor/models/qwen2.py
@@ -13589,41 +13637,85 @@ index c4d02e5dd..2831a5a12 100644
)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 1e6ff1fec..e2480326a 100644
index 1e6ff1fec..90ebe5ca9 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -304,6 +304,10 @@ class Qwen2_5_VisionAttention(nn.Module):
@@ -302,23 +302,33 @@ class Qwen2_5_VisionAttention(nn.Module):
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
+ head_dim = q.shape[-1]
+ import math
+ import xe_addons
+ scale = 1 / math.sqrt(head_dim)
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
@@ -312,10 +316,16 @@ class Qwen2_5_VisionAttention(nn.Module):
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
for x in [q_i, k_i, v_i])
- # Execute attention entry by entry for speed & less VRAM.
- outputs = []
- for i in range(1, len(cu_seqlens)):
- start_idx = cu_seqlens[i - 1]
- end_idx = cu_seqlens[i]
- q_i = q[:, start_idx:end_idx]
- k_i = k[:, start_idx:end_idx]
- v_i = v[:, start_idx:end_idx]
- q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d")
- for x in [q_i, k_i, v_i])
- output_i = F.scaled_dot_product_attention(q_i,
- k_i,
- v_i,
- dropout_p=0.0)
+ # output_i = F.scaled_dot_product_attention(q_i,
+ # k_i,
+ # v_i,
+ # dropout_p=0.0)
+ output_i = xe_addons.sdp_non_causal(
+ q_i.contiguous(),
+ k_i.contiguous(),
+ v_i.contiguous(),
+ None,
+ scale)
output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
- output_i = rearrange(output_i, "b h s d -> b s h d ")
- outputs.append(output_i)
- context_layer = torch.cat(outputs, dim=1)
+ # TODO(xiangyu): Maybe add attn_backend xpu?
+ q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+ from vllm._ipex_ops import ipex_ops
+ output = torch.empty(
+ (q.shape[0], q.shape[1], q.shape[2]),
+ dtype=q.dtype,
+ device=q.device)
+ import math
+ head_dim = q.shape[-1]
+ scale = 1 / math.sqrt(head_dim)
+ ipex_ops.varlen_attention(q, k, v, output,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ pdropout=0,
+ softmax_scale=scale,
+ zero_tensors=False,
+ is_causal=False,
+ return_softmax=False,
+ gen_=None,
+ logits_soft_cap=0
+ )
+
+ context_layer = rearrange(output,
+ "(b s) ... -> b s ...",
+ b=batch_size)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
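
The rewritten TORCH_SDPA branch above stops looping over cu_seqlens segments with F.scaled_dot_product_attention and instead packs all tokens into one (b*s) tensor and calls ipex_ops.varlen_attention once. For checking outputs, a plain-PyTorch reference of the computation the removed loop performed (this is a sketch for comparison, not the IPEX kernel):

import math
import torch
import torch.nn.functional as F

def varlen_sdpa_reference(q, k, v, cu_seqlens):
    # q/k/v: [total_tokens, num_heads, head_dim]; cu_seqlens: [num_seqs + 1] prefix sums.
    scale = 1.0 / math.sqrt(q.shape[-1])
    out = torch.empty_like(q)
    for i in range(1, len(cu_seqlens)):
        s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        # [seq, heads, dim] -> [heads, seq, dim] for SDPA, then back.
        q_i, k_i, v_i = (x[s:e].transpose(0, 1) for x in (q, k, v))
        o_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0, scale=scale)
        out[s:e] = o_i.transpose(0, 1)
    return out
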
@@ -613,10 +623,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
- if self.attn_backend == _Backend.FLASH_ATTN:
- max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- elif self.attn_backend == _Backend.XFORMERS:
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ # if self.attn_backend == _Backend.FLASH_ATTN:
+ # max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ # elif self.attn_backend == _Backend.XFORMERS:
+ # seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
return max_seqlen, seqlens
def forward(
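
After the hunk above, max_seqlen is computed unconditionally (rather than only for FLASH_ATTN), since the XPU varlen path always needs it. The value is simply the largest gap between consecutive cumulative sequence lengths:

import torch

cu_seqlens = torch.tensor([0, 5, 12, 20])  # three sequences of length 5, 7, 8
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
assert max_seqlen == 8
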
@@ -1082,7 +1093,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input=image_input,
video_input=video_input)
input_ids = None
-
+
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index a7800d415..26af87512 100644
--- a/vllm/model_executor/models/qwen2_vl.py
@@ -15133,10 +15225,10 @@ index c271f438e..cf7180606 100755
assert sliding_window == (-1, -1), (
diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py
new file mode 100644
index 000000000..f4a435eaa
index 000000000..964696cfe
--- /dev/null
+++ b/vllm/v1/attention/backends/ipex_attn.py
@@ -0,0 +1,392 @@
@@ -0,0 +1,404 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
@@ -15152,6 +15244,10 @@ index 000000000..f4a435eaa
+from vllm.attention.backends.ipex_attn import use_gqa_kernel
+from vllm.utils import is_bmg_platform
+import os
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class IPEXAttentionMetadata(FlashAttentionMetadata):
@@ -15246,6 +15342,12 @@ index 000000000..f4a435eaa
+ "are not implemented for "
+ "IpexAttnBackendImpl")
+
+ flag = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None)
+ self.ipex_varlen_attn = False
+ if flag is not None:
+ self.ipex_varlen_attn = True
+ logger.info_once(f"V1 engine using varlen_attention for prefilling.")
+
+ def forward(
+ self,
+ layer: AttentionLayer,
@@ -15293,6 +15395,7 @@ index 000000000..f4a435eaa
+ self.sliding_window,
+ self.alibi_slopes,
+ self.logits_soft_cap,
+ self.ipex_varlen_attn,
+ )
+ return output.view(-1, self.num_heads * self.head_size)
+
@@ -15367,6 +15470,7 @@ index 000000000..f4a435eaa
+ sliding_window: Optional[List[int]] = None,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ logits_soft_cap: Optional[float] = None,
+ flag: Optional[bool] = False,
+) -> None:
+ context = get_forward_context()
+ current_metadata = context.attn_metadata
@@ -15382,7 +15486,7 @@ index 000000000..f4a435eaa
+ key = key.view(-1, num_kv_heads, head_size)
+ value = value.view(-1, num_kv_heads, head_size)
+
+ if is_bmg_platform:
+ if flag or is_bmg_platform:
+ key_cache, value_cache = kv_cache.unbind(0)
+ ipex_ops.reshape_and_cache_flash(
+ key[:num_actual_tokens],
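
The new V1 ipex_attn backend shown above gates its prefill path on the IPEX_LLM_PREFILL_VARLEN_BACKEND environment variable: any value (the content is not inspected) enables varlen_attention and the flash-style KV-cache write, the same condition as the "flag or is_bmg_platform" check. A minimal sketch of reading that switch (the helper name is illustrative):

import os

def prefill_varlen_enabled() -> bool:
    # Mirrors the backend's check: presence of the variable is enough;
    # its value is never inspected.
    return os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None) is not None

Usage: set IPEX_LLM_PREFILL_VARLEN_BACKEND=1 in the environment before launching the vLLM server to take the varlen prefill path on non-BMG platforms as well.
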
@@ -17087,7 +17191,7 @@ index 000000000..dffc7b367
+ return (attn_metadata, encoder_input_tokens_tensor,
+ encoder_input_positions_tensor)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 9d49b4385..78e0c54f2 100644
index 9d49b4385..dc5e95f4e 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -5,8 +5,8 @@ import time
@@ -17735,15 +17839,17 @@ index 9d49b4385..78e0c54f2 100644
max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
self.model_config)
if max_mm_tokens > 0:
@@ -461,6 +820,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
@@ -461,6 +820,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
+ '''
+ if "phi4mm" in self.model_config.hf_config.model_type:
+ max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
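
In the profiling hunk above, the computed max_num_seqs fallback is disabled and a single dummy sequence is forced whenever the model type contains "phi4mm". A hedged restatement of the added guard as a standalone helper (the function name is illustrative):

def profiling_max_num_seqs(model_type: str, computed: int) -> int:
    # Same condition as the hunk: phi4mm multimodal profiling runs one sequence at a time;
    # other models keep the computed value (clamped to at least 1).
    return 1 if "phi4mm" in model_type else max(computed, 1)

assert profiling_max_num_seqs("phi4mm", 8) == 1
assert profiling_max_num_seqs("qwen2_5_vl", 8) == 8
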
@@ -479,11 +839,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
@@ -479,11 +841,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
seq_data={group_id: dummy_data.seq_data},
sampling_params=sampling_params,
block_tables=None,
@ -17759,7 +17865,7 @@ index 9d49b4385..78e0c54f2 100644
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
@@ -493,25 +856,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
@@ -493,25 +858,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
@ -17810,7 +17916,7 @@ index 9d49b4385..78e0c54f2 100644
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
@@ -524,6 +901,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
@@ -524,6 +903,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
return builder.build() # type: ignore
@ -17833,7 +17939,7 @@ index 9d49b4385..78e0c54f2 100644
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -563,6 +956,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
@@ -563,6 +958,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
raise ValueError(
"XPUModelRunner does not support multi-step execution.")
@ -17846,7 +17952,7 @@ index 9d49b4385..78e0c54f2 100644
model_executable = self.model
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
@@ -612,3 +1011,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
@@ -612,3 +1013,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
output.model_forward_time = model_forward_time
return [output]