update patches (#13290)

Signed-off-by: liu-shaojun <shaojun.liu@intel.com>
This commit is contained in:
Shaojun Liu 2025-08-14 10:15:48 +08:00 committed by GitHub
parent 9cfdf143a2
commit cac90a9238
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -9200,7 +9200,7 @@ index c3d210c27..6a9c7c798 100644
+ max_seq_length, slice_offset,
+ slice_size, add_inputs)
diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
index d3c61ea26..7d7adad15 100644
index d3c61ea26..8e1219180 100644
--- a/vllm/attention/backends/ipex_attn.py
+++ b/vllm/attention/backends/ipex_attn.py
@@ -5,7 +5,7 @@ from dataclasses import dataclass
@ -9723,7 +9723,7 @@ index d3c61ea26..7d7adad15 100644
def forward(
self,
layer: AttentionLayer,
@@ -195,84 +566,224 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
@@ -195,84 +566,236 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
@ -9948,6 +9948,17 @@ index d3c61ea26..7d7adad15 100644
+ self.logits_soft_cap,
+ self.scale).squeeze(0).movedim(
+ query.dim() - 2, 0)
+ elif attn_type != AttentionType.DECODER:
+ import xe_addons
+ if mask is not None:
+ mask = mask.unsqueeze(0)
+ sub_out = xe_addons.sdp_non_causal(
+ query[None, :, start_q:end_q, :].contiguous(),
+ key[None, :, start_kv:end_kv, :].contiguous(),
+ value[None, :, start_kv:end_kv, :].contiguous(),
+ mask,
+ scale).squeeze(0).movedim(
+ query.dim() - 2, 0)
+ else:
+ sub_out = torch.nn.functional.scaled_dot_product_attention(
+ query[None, :, start_q:end_q, :],
@ -9965,6 +9976,7 @@ index d3c61ea26..7d7adad15 100644
# prefix-enabled attention
- raise RuntimeError(
- "IPEX backend doesn't support prefix decoding.")
+ query = query[:num_prefill_tokens]
+ if self.num_kv_heads != self.num_heads:
+ key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+ value = value.repeat_interleave(self.num_queries_per_kv,
@ -9981,8 +9993,8 @@ index d3c61ea26..7d7adad15 100644
+ out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item(), torch.amax(query_len).item())
+ else:
+ out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item())
+ assert output[:num_prefill_query_tokens].shape == out.shape
+ output[:num_prefill_query_tokens] = out
+ assert output[:num_prefill_tokens].shape == out.shape
+ output[:num_prefill_tokens] = out
- else:
+ if decode_meta := attn_metadata.decode_metadata:
@ -10004,7 +10016,7 @@ index d3c61ea26..7d7adad15 100644
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
@@ -281,59 +792,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
@@ -281,59 +804,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory
# shortage.
@ -10136,7 +10148,7 @@ index d3c61ea26..7d7adad15 100644
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
@@ -386,3 +924,114 @@ def _make_sliding_window_bias(
@@ -386,3 +936,114 @@ def _make_sliding_window_bias(
attn_biases.append(mask.to(dtype))
return attn_biases
@ -13637,18 +13649,28 @@ index c4d02e5dd..2831a5a12 100644
)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 1e6ff1fec..90ebe5ca9 100644
index 1e6ff1fec..efd659075 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -302,23 +302,33 @@ class Qwen2_5_VisionAttention(nn.Module):
@@ -302,23 +302,37 @@ class Qwen2_5_VisionAttention(nn.Module):
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
- # Execute attention entry by entry for speed & less VRAM.
- outputs = []
- for i in range(1, len(cu_seqlens)):
- start_idx = cu_seqlens[i - 1]
- end_idx = cu_seqlens[i]
+ # TODO(xiangyu): Maybe add attn_backend xpu?
outputs = []
+ head_dim = q.shape[-1]
+ import math
+ import xe_addons
+ head_dim = q.shape[-1]
+ scale = 1 / math.sqrt(head_dim)
+ q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
- q_i = q[:, start_idx:end_idx]
- k_i = k[:, start_idx:end_idx]
- v_i = v[:, start_idx:end_idx]
@ -13659,39 +13681,28 @@ index 1e6ff1fec..90ebe5ca9 100644
- v_i,
- dropout_p=0.0)
- output_i = rearrange(output_i, "b h s d -> b s h d ")
- outputs.append(output_i)
+ q_i = q[:, :, start_idx:end_idx]
+ k_i = k[:, :, start_idx:end_idx]
+ v_i = v[:, :, start_idx:end_idx]
+ # output_i = F.scaled_dot_product_attention(q_i,
+ # k_i,
+ # v_i,
+ # dropout_p=0.0)
+ output_i = xe_addons.sdp_non_causal(
+ q_i,
+ k_i,
+ v_i,
+ None,
+ scale)
+ # output_i = rearrange(output_i, "b h s d -> b s h d ")
outputs.append(output_i)
- context_layer = torch.cat(outputs, dim=1)
+ # TODO(xiangyu): Maybe add attn_backend xpu?
+ q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+ from vllm._ipex_ops import ipex_ops
+ output = torch.empty(
+ (q.shape[0], q.shape[1], q.shape[2]),
+ dtype=q.dtype,
+ device=q.device)
+ import math
+ head_dim = q.shape[-1]
+ scale = 1 / math.sqrt(head_dim)
+ ipex_ops.varlen_attention(q, k, v, output,
+ cu_seqlens,
+ cu_seqlens,
+ max_seqlen,
+ max_seqlen,
+ pdropout=0,
+ softmax_scale=scale,
+ zero_tensors=False,
+ is_causal=False,
+ return_softmax=False,
+ gen_=None,
+ logits_soft_cap=0
+ )
+
+ context_layer = rearrange(output,
+ "(b s) ... -> b s ...",
+ b=batch_size)
+ context_layer = torch.cat(outputs, dim=2)
+ context_layer = rearrange(context_layer, "b h s d -> b s h d")
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
@@ -613,10 +623,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
@@ -613,10 +627,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
max_seqlen, seqlens = None, None
@ -13707,7 +13718,7 @@ index 1e6ff1fec..90ebe5ca9 100644
return max_seqlen, seqlens
def forward(
@@ -1082,7 +1093,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
@@ -1082,7 +1097,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input=image_input,
video_input=video_input)
input_ids = None
@ -17191,7 +17202,7 @@ index 000000000..dffc7b367
+ return (attn_metadata, encoder_input_tokens_tensor,
+ encoder_input_positions_tensor)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 9d49b4385..dc5e95f4e 100644
index 9d49b4385..065a09614 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -5,8 +5,8 @@ import time
@ -17526,7 +17537,7 @@ index 9d49b4385..dc5e95f4e 100644
+ slot_mapping_tensor = torch.tensor(slot_mapping,
+ dtype=torch.long,
+ device=self.device)
+ if need_block_table or "bge" in self.runner.model_config.model.lower():
+ if need_block_table or "reranker" in self.runner.model_config.model.lower() or "embedding" in self.runner.model_config.model.lower():
+ seq_lens_tensor = torch.tensor(seq_lens,
+ dtype=torch.int,
+ device=self.device)
@ -18998,7 +19009,7 @@ index 000000000..550bf81e8
+
+ return pooling_metadata
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index 3aea0d741..7421826c3 100644
index 3aea0d741..90ee8e2f6 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -2,9 +2,10 @@
@ -19041,16 +19052,35 @@ index 3aea0d741..7421826c3 100644
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
@@ -65,7 +75,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
@@ -65,7 +75,26 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[CacheEngine]
- self.gpu_cache: Optional[List[List[torch.Tensor]]]
+ self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
+
+ # Torch profiler. Enabled and configured through env vars:
+ # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+ import vllm.envs as envs
+ if envs.VLLM_TORCH_PROFILER_DIR:
+ torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+ logger.info("Profiling enabled. Traces will be saved to: %s",
+ torch_profiler_trace_dir)
+ self.profiler = torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.XPU,
+ ],
+ with_stack=True,
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(
+ torch_profiler_trace_dir, use_gzip=True))
+ else:
+ self.profiler = None
+
def init_device(self) -> None:
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
@@ -99,16 +109,74 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
@@ -99,16 +128,74 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
@ -19127,7 +19157,7 @@ index 3aea0d741..7421826c3 100644
total_gpu_memory = torch.xpu.get_device_properties(
self.local_rank).total_memory
free_gpu_memory = total_gpu_memory - used_memory
@@ -132,6 +200,20 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
@@ -132,6 +219,20 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
num_cpu_blocks = max(num_cpu_blocks, 0)
gc.collect()
torch.xpu.empty_cache()
@ -19148,7 +19178,7 @@ index 3aea0d741..7421826c3 100644
return num_gpu_blocks, num_cpu_blocks
def _warm_up_model(self) -> None:
@@ -177,9 +259,10 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
@@ -177,9 +278,10 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
# global all_reduce needed for overall oneccl warm up