From 66eb054988dbc1253dc43475347e98e262446164 Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Mon, 19 May 2025 16:54:21 +0800 Subject: [PATCH] Update vllm patch (#13164) --- .../xpu/docker/vllm_for_multi_arc.patch | 1357 +++++++++++++++-- 1 file changed, 1196 insertions(+), 161 deletions(-) diff --git a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch index aa898fc4..1888f041 100644 --- a/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch +++ b/docker/llm/serving/xpu/docker/vllm_for_multi_arc.patch @@ -8692,7 +8692,7 @@ index 000000000..e98db9b65 + tensor_parallel_size=1, + ) diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py -index c3d210c27..8dd101608 100644 +index c3d210c27..6a9c7c798 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -1,6 +1,4 @@ @@ -8703,7 +8703,7 @@ index c3d210c27..8dd101608 100644 import torch -@@ -13,6 +11,7 @@ try: +@@ -13,14 +11,15 @@ try: except ImportError as e: logger.warning("Import error msg: %s", e.msg) @@ -8711,6 +8711,16 @@ index c3d210c27..8dd101608 100644 class ipex_ops: + @staticmethod + def _reshape_activation_tensor( + x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +- num = x.size(0) +- d = x.size(1) // 2 ++ num = x.size(-2) ++ d = x.size(-1) // 2 + x = x.reshape(num, 2, d) + x1, x2 = torch.chunk(x, chunks=2, dim=1) + x1 = x1.reshape(num, d) @@ -29,23 +28,31 @@ class ipex_ops: @staticmethod @@ -9190,7 +9200,7 @@ index c3d210c27..8dd101608 100644 + max_seq_length, slice_offset, + slice_size, add_inputs) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py -index d3c61ea26..3ec6ee9ee 100644 +index d3c61ea26..7d7adad15 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -5,7 +5,7 @@ from dataclasses import dataclass @@ -9235,7 +9245,16 @@ index d3c61ea26..3ec6ee9ee 100644 @dataclass -@@ -77,6 +79,11 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): +@@ -74,9 +76,120 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor ++ ++ max_prefill_seq_len: int ++ # Maximum sequence length among decode batch. 0 if there are prefill ++ # requests only. ++ max_decode_seq_len: int ++ seq_lens: Optional[List[int]] seqlen_q: Optional[torch.Tensor] max_seqlen: Optional[int] @@ -9244,10 +9263,116 @@ index d3c61ea26..3ec6ee9ee 100644 + + _cached_prefill_metadata: Optional["IpexAttnMetadata"] = None + _cached_decode_metadata: Optional["IpexAttnMetadata"] = None ++ ++ # Begin encoder attn & enc/dec cross-attn fields... ++ ++ # Encoder sequence lengths representation ++ encoder_seq_lens: Optional[List[int]] = None ++ encoder_seq_lens_tensor: Optional[torch.Tensor] = None ++ # (batch_size + 1,). The cumulative sequence lengths of the sequences in ++ # the batch, used to index into sequence. E.g., if the sequence length is ++ # [4, 6], it is [0, 4, 10]. 
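# [Editor's note: illustrative sketch, not part of the patch] The comment above
# describes the (batch_size + 1,) cumulative start-location layout used by the
# encoder_seq_start_loc field declared just below. A minimal standalone example
# of deriving such a tensor from per-sequence lengths with torch.cumsum, the
# same construction the XPU encoder/decoder runner uses later in this patch:
import torch

encoder_seq_lens = torch.tensor([4, 6], dtype=torch.int32)
encoder_seq_start_loc = torch.zeros(encoder_seq_lens.shape[0] + 1,
                                    dtype=torch.int32)
torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32,
             out=encoder_seq_start_loc[1:])
# encoder_seq_start_loc -> tensor([0, 4, 10], dtype=torch.int32)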
++ encoder_seq_start_loc: Optional[torch.Tensor] = None ++ seq_start_loc: Optional[torch.Tensor] = None ++ # Maximum sequence length among encoder sequences ++ max_encoder_seq_len: Optional[int] = None ++ # Number of tokens input to encoder ++ num_encoder_tokens: Optional[int] = None ++ ++ # Cross-attention memory-mapping data structures: slot mapping ++ # and block tables ++ cross_slot_mapping: Optional[torch.Tensor] = None ++ cross_block_tables: Optional[torch.Tensor] = None ++ ++ ++ @property ++ def is_all_encoder_attn_metadata_set(self): ++ ''' ++ All attention metadata required for encoder attention is set. ++ ''' ++ return ((self.encoder_seq_lens is not None) ++ and (self.encoder_seq_lens_tensor is not None) ++ and (self.max_encoder_seq_len is not None)) ++ ++ ++ @property ++ def is_all_cross_attn_metadata_set(self): ++ ''' ++ All attention metadata required for enc/dec cross-attention is set. ++ ++ Superset of encoder attention required metadata. ++ ''' ++ return (self.is_all_encoder_attn_metadata_set ++ and (self.cross_slot_mapping is not None) ++ and (self.cross_block_tables is not None)) ++ ++ ++ def get_attn_bias( ++ self, ++ attn_type: str, ++ ) -> Optional[List[torch.Tensor]]: ++ ''' ++ Extract appropriate attention bias from attention metadata ++ according to attention type. ++ ++ Arguments: ++ ++ * attn_metadata: Attention metadata structure associated with attention ++ * attn_type: encoder attention, decoder self-attention, ++ encoder/decoder cross-attention ++ ++ Returns: ++ * Appropriate attention bias value given the attention type ++ ''' ++ ++ if (attn_type == AttentionType.DECODER ++ or attn_type == AttentionType.ENCODER_ONLY): ++ return self.attn_bias ++ elif attn_type == AttentionType.ENCODER: ++ return self.encoder_attn_bias ++ elif attn_type == AttentionType.ENCODER_DECODER: ++ return self.cross_attn_bias ++ else: ++ raise AttributeError(f"Invalid attention type {str(attn_type)}") ++ ++ ++ def set_attn_bias( ++ self, ++ attn_bias: List[torch.Tensor], ++ attn_type: str, ++ ) -> None: ++ ''' ++ Update appropriate attention bias field of attention metadata, ++ according to attention type. ++ ++ Arguments: ++ ++ * attn_metadata: Attention metadata structure associated with attention ++ * attn_bias: The desired attention bias value ++ * attn_type: encoder attention, decoder self-attention, ++ encoder/decoder cross-attention ++ ''' ++ ++ if (attn_type == AttentionType.DECODER ++ or attn_type == AttentionType.ENCODER_ONLY): ++ self.attn_bias = attn_bias ++ elif attn_type == AttentionType.ENCODER: ++ self.encoder_attn_bias = attn_bias ++ elif attn_type == AttentionType.ENCODER_DECODER: ++ self.cross_attn_bias = attn_bias ++ else: ++ raise AttributeError(f"Invalid attention type {str(attn_type)}") ++ def __post_init__(self): # Set during the execution of the first attention op. -@@ -89,21 +96,143 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): +@@ -85,25 +198,253 @@ class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + # from xformer API. 
+ # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None ++ self.encoder_attn_bias: Optional[List[torch.Tensor]] = None ++ self.cross_attn_bias: Optional[List[torch.Tensor]] = None + @property def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: # Currently chunked prefill is not supported @@ -9281,10 +9406,18 @@ index d3c61ea26..3ec6ee9ee 100644 + # max_query_len=self.max_query_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1] if (torch.is_tensor(self.query_start_loc)) else None, -+ # seq_start_loc=None, ++ seq_start_loc=self.seq_start_loc[:self.num_prefills + 1] if (torch.is_tensor(self.seq_start_loc)) else None, + context_lens=self.context_lens[:self.num_prefills] if (torch.is_tensor(self.context_lens)) else None, + block_tables=self.block_tables[:self.num_prefills], + enable_kv_scales_calculation=False, ++ # Begin encoder & cross attn fields below... ++ max_prefill_seq_len=self.max_prefill_seq_len, ++ encoder_seq_lens=self.encoder_seq_lens, ++ encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, ++ encoder_seq_start_loc=self.encoder_seq_start_loc, ++ max_encoder_seq_len=self.max_encoder_seq_len, ++ cross_slot_mapping=self.cross_slot_mapping, ++ cross_block_tables=self.cross_block_tables + ) + return self._cached_prefill_metadata @@ -9297,7 +9430,6 @@ index d3c61ea26..3ec6ee9ee 100644 return None - return self -- + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None @@ -9318,15 +9450,106 @@ index d3c61ea26..3ec6ee9ee 100644 + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + # seq_start_loc=None, ++ seq_start_loc=self.seq_start_loc[self.num_prefills:], + context_lens=self.context_lens[self.num_prefills:] if (torch.is_tensor(self.context_lens)) else None, + block_tables=self.block_tables[self.num_prefills:], + enable_kv_scales_calculation=False, ++ # Begin encoder & cross attn fields below... ++ max_prefill_seq_len=self.max_prefill_seq_len, ++ encoder_seq_lens=self.encoder_seq_lens, ++ encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, ++ encoder_seq_start_loc=self.encoder_seq_start_loc, ++ max_encoder_seq_len=self.max_encoder_seq_len, ++ cross_slot_mapping=self.cross_slot_mapping, ++ cross_block_tables=self.cross_block_tables + ) + return self._cached_decode_metadata + ++ ++ def get_seq_lens( ++ self, ++ attn_type: str, ++ ): ++ ''' ++ Extract appropriate sequence lengths from attention metadata ++ according to attention type. 
++ ++ Arguments: ++ ++ * attn_metadata: Attention metadata structure associated with attention ++ * attn_type: encoder attention, decoder self-attention, ++ encoder/decoder cross-attention ++ ++ Returns: ++ * Appropriate sequence lengths tensor for query ++ * Appropriate sequence lengths tensor for key & value ++ ''' ++ ++ if (attn_type == AttentionType.DECODER ++ or attn_type == AttentionType.ENCODER_ONLY): ++ seq_lens_q = self.seq_lens ++ seq_lens_kv = self.seq_lens ++ elif attn_type == AttentionType.ENCODER: ++ seq_lens_q = self.encoder_seq_lens ++ seq_lens_kv = self.encoder_seq_lens ++ elif attn_type == AttentionType.ENCODER_DECODER: ++ seq_lens_q = self.seq_lens ++ seq_lens_kv = self.encoder_seq_lens ++ else: ++ raise AttributeError(f"Invalid attention type {str(attn_type)}") ++ return seq_lens_q, seq_lens_kv ++ ++ ++ def get_seq_len_block_table_args( ++ self, ++ attn_type: str, ++ ) -> tuple: ++ ''' ++ The particular choice of sequence-length- and block-table-related ++ attributes which should be extracted from attn_metadata is dependent ++ on the type of attention operation. ++ ++ Decoder attn -> select entirely decoder self-attention-related fields ++ Encoder/decoder cross-attn -> select encoder sequence lengths & ++ cross-attn block-tables fields ++ Encoder attn -> select encoder sequence lengths fields & no block tables ++ ++ Arguments: ++ ++ * attn_metadata: Attention metadata structure associated with attention ++ * is_prompt: True if prefill, False otherwise ++ * attn_type: encoder attention, decoder self-attention, ++ encoder/decoder cross-attention + ++ Returns: ++ ++ * Appropriate sequence-lengths tensor ++ * Appropriate max sequence-length scalar ++ * Appropriate block tables (or None) ++ ''' ++ ++ if (attn_type == AttentionType.DECODER ++ or attn_type == AttentionType.ENCODER_ONLY): ++ # Decoder self-attention ++ # Choose max_seq_len based on whether we are in prompt_run ++ return (self.seq_lens_tensor, self.max_decode_seq_len, ++ self.block_tables) ++ elif attn_type == AttentionType.ENCODER_DECODER: ++ # Enc/dec cross-attention KVs match encoder sequence length; ++ # cross-attention utilizes special "cross" block tables ++ return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, ++ self.cross_block_tables) ++ elif attn_type == AttentionType.ENCODER: ++ # No block tables associated with encoder attention ++ return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, ++ None) ++ else: ++ raise AttributeError(f"Invalid attention type {str(attn_type)}") ++ ++ + def advance_step(self, num_seqs, num_queries): + assert num_seqs == num_queries -+ ++ + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs @@ -9383,8 +9606,13 @@ index d3c61ea26..3ec6ee9ee 100644 + + +def use_sdp_causal(head_dim, query_states, logits_soft_cap, attn_type): ++ disabled = os.environ.get('IPEX_LLM_DISABLE_SDP_CAUSAL', None) ++ if disabled is not None: ++ disabled = int(disabled) ++ if disabled == 1: ++ return False + return ( -+ (logits_soft_cap != 0 # for gemma model ++ (logits_soft_cap != 0 # for gemma model + or head_dim in [-1, 64, 80, 96, 128, 256]) # for now + and query_states.device.type == "xpu" # GPU + and query_states.dtype in [torch.float, torch.half] # fp32/fp16 @@ -9400,7 +9628,7 @@ index d3c61ea26..3ec6ee9ee 100644 class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): -@@ -119,6 +248,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -119,6 +460,7 @@ class 
IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, @@ -9408,7 +9636,7 @@ index d3c61ea26..3ec6ee9ee 100644 ) -> None: if blocksparse_params is not None: raise ValueError( -@@ -132,29 +262,40 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -132,29 +474,40 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): self.alibi_slopes = alibi_slopes self.sliding_window = sliding_window self.kv_cache_dtype = kv_cache_dtype @@ -9437,10 +9665,11 @@ index d3c61ea26..3ec6ee9ee 100644 - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " -+ if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: -+ raise NotImplementedError("Encoder/decoder cross-attention " -+ "is not implemented for " - "IpexAttnBackendImpl") +- "IpexAttnBackendImpl") ++ # if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: ++ # raise NotImplementedError("Encoder/decoder cross-attention " ++ # "is not implemented for " ++ # "IpexAttnBackendImpl") + if kv_cache_dtype not in _IPEX_BACKEND_SUPPORTED_KV_CACHE_FORMAT: + raise NotImplementedError(f"IPEX backend does not support " + "KV cache format {kv_cache_dtype}") @@ -9449,7 +9678,7 @@ index d3c61ea26..3ec6ee9ee 100644 + if not self.using_gqa_kernel and kv_cache_dtype == "fp8": + raise NotImplementedError(f"IPEX backend currently only supports " + "fp8 kv cache in group-query attention") -+ ++ + self.ipex_varlen_attn = False + flag = os.getenv("IPEX_LLM_PREFILL_VARLEN_BACKEND", None) + if flag is not None: @@ -9458,7 +9687,7 @@ index d3c61ea26..3ec6ee9ee 100644 def split_kv_cache( self, -@@ -162,16 +303,34 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): +@@ -162,16 +515,34 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): num_kv_heads: int, head_size: int, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -9494,10 +9723,28 @@ index d3c61ea26..3ec6ee9ee 100644 def forward( self, layer: AttentionLayer, -@@ -202,75 +361,177 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - +@@ -195,84 +566,224 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + Returns: + shape = [num_tokens, num_heads * head_size] + """ ++ attn_type = self.attn_type ++ if (attn_type == AttentionType.ENCODER ++ and (not attn_metadata.is_all_encoder_attn_metadata_set)): ++ raise AttributeError("Encoder attention requires setting " ++ "encoder metadata attributes.") ++ elif (attn_type == AttentionType.ENCODER_DECODER ++ and (not attn_metadata.is_all_cross_attn_metadata_set)): ++ raise AttributeError("Encoder/decoder cross-attention " ++ "requires setting cross-attention " ++ "metadata attributes.") ++ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
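# [Editor's note: illustrative sketch, not part of the patch] The view() calls
# just below split the flat [num_tokens, num_heads * head_size] projections into
# the per-head layout the attention kernels expect. A minimal standalone example
# with assumed sizes:
import torch

num_tokens, num_heads, head_size = 7, 8, 64
query = torch.randn(num_tokens, num_heads * head_size)
query = query.view(-1, num_heads, head_size)  # shape becomes [7, 8, 64]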
+ query = query.view(-1, self.num_heads, self.head_size) +- key = key.view(-1, self.num_kv_heads, self.head_size) +- value = value.view(-1, self.num_kv_heads, self.head_size) +- - if kv_cache.numel() > 0: - key_cache, value_cache = self.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) @@ -9516,78 +9763,97 @@ index d3c61ea26..3ec6ee9ee 100644 - assert attn_metadata.seq_lens is not None - if (kv_cache.numel() == 0 - or attn_metadata.block_tables.numel() == 0): -+ if kv_cache.numel() > 0 and self.attn_type == AttentionType.DECODER: ++ if key is not None: ++ key = key.view(-1, self.num_kv_heads, self.head_size) ++ if value is not None: ++ value = value.view(-1, self.num_kv_heads, self.head_size) ++ ++ if kv_cache.numel() > 0 and attn_type != AttentionType.ENCODER: + if self.using_gqa_kernel: + key_cache, value_cache = self.split_kv_cache_ipexllm( -+ kv_cache, self.num_kv_heads, self.head_size) -+ ipex_ops.reshape_and_cache_ipexllm( -+ key, -+ value, -+ key_cache, -+ value_cache, -+ attn_metadata.slot_mapping.flatten(), -+ self.kv_cache_dtype, -+ layer._k_scale, -+ layer._v_scale, -+ ) ++ kv_cache, self.num_kv_heads, self.head_size) + else: + key_cache, value_cache = self.split_kv_cache( -+ kv_cache, self.num_kv_heads, self.head_size) -+ ipex_ops.reshape_and_cache( -+ key, -+ value, -+ key_cache, -+ value_cache, -+ attn_metadata.slot_mapping.flatten(), -+ self.kv_cache_dtype, -+ layer._k_scale, -+ layer._v_scale, -+ ) ++ kv_cache, self.num_kv_heads, self.head_size) ++ if (key is not None) and ( ++ value is not None): ++ if attn_type == AttentionType.ENCODER_DECODER: ++ updated_slot_mapping = attn_metadata.cross_slot_mapping ++ else: ++ updated_slot_mapping = attn_metadata.slot_mapping ++ if self.using_gqa_kernel: ++ ipex_ops.reshape_and_cache_ipexllm( ++ key, ++ value, ++ key_cache, ++ value_cache, ++ updated_slot_mapping.flatten(), ++ self.kv_cache_dtype, ++ layer._k_scale, ++ layer._v_scale, ++ ) ++ else: ++ ipex_ops.reshape_and_cache( ++ key, ++ value, ++ key_cache, ++ value_cache, ++ updated_slot_mapping.flatten(), ++ self.kv_cache_dtype, ++ layer._k_scale, ++ layer._v_scale, ++ ) + -+ # New added code-segment -+ num_prefill_tokens = attn_metadata.num_prefill_tokens -+ num_decode_tokens = attn_metadata.num_decode_tokens -+ assert query.shape[0] == num_prefill_tokens + num_decode_tokens -+ assert key.shape[0] == num_prefill_tokens + num_decode_tokens -+ assert value.shape[0] == num_prefill_tokens + num_decode_tokens ++ if attn_type != AttentionType.ENCODER: ++ # Decoder self-attention supports chunked prefill. ++ # Encoder/decoder cross-attention requires no chunked ++ # prefill (100% prefill or 100% decode tokens, no mix) ++ num_prefill_tokens = attn_metadata.num_prefill_tokens ++ num_decode_tokens = attn_metadata.num_decode_tokens ++ else: ++ # Encoder attention - chunked prefill is not applicable; ++ # derive token-count from query shape & and treat them ++ # as 100% prefill tokens ++ assert attn_metadata.num_encoder_tokens is not None ++ num_prefill_tokens = attn_metadata.num_encoder_tokens ++ num_decode_tokens = 0 + ++ if attn_type == AttentionType.DECODER: ++ # Only enforce this shape-constraint for decoder ++ # self-attention ++ assert key.shape[0] == num_prefill_tokens + num_decode_tokens ++ assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) -+ # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] -+ # QKV for prefill. 
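# [Editor's note: illustrative sketch, not part of the patch] The slicing around
# here relies on vLLM's flattened batch layout: all prefill tokens come first,
# followed (typically) by one token per decoding sequence. A minimal standalone
# example of the prefill/decode split, with assumed token counts:
import torch

num_prefill_tokens, num_decode_tokens = 5, 2  # e.g. prompts of 3 + 2 tokens, two decodes
num_heads, head_size = 8, 64
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)
decode_query = query[num_prefill_tokens:]   # consumed by the paged-attention kernels
prefill_query = query[:num_prefill_tokens]  # consumed by the prompt/varlen path
assert decode_query.shape[0] == num_decode_tokens
assert prefill_query.shape[0] == num_prefill_tokens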
-+ query = query[:num_prefill_tokens] -+ key = key[:num_prefill_tokens] -+ value = value[:num_prefill_tokens] + -+ assert query.shape[0] == num_prefill_tokens -+ assert decode_query.shape[0] == num_decode_tokens -+ # If mask is not set, then is_causal=True -+ # If mask is set, then is_causal=False + is_causal = not self.need_mask -+ if self.attn_type == AttentionType.ENCODER_ONLY: ++ if attn_type == AttentionType.ENCODER_ONLY: + is_causal = False + + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.seq_lens is not None + if (kv_cache is None or prefill_meta.block_tables.numel() == 0): ++ ++ if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - if attn_metadata.attn_bias is None: -+ if prefill_meta.attn_bias is None: ++ attn_masks = attn_metadata.get_attn_bias(attn_type) ++ if attn_masks is None: if self.alibi_slopes is not None: -+ self.alibi_slopes = self.alibi_slopes.to(query.device) - att_masks = _make_alibi_bias( +- att_masks = _make_alibi_bias( ++ attn_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, -- attn_metadata.seq_lens) # type: ignore -+ prefill_meta.seq_lens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: - att_masks = _make_sliding_window_bias( -- attn_metadata.seq_lens, self.sliding_window, -+ prefill_meta.seq_lens, self.sliding_window, +- att_masks = _make_sliding_window_bias( ++ assert attn_metadata.seq_lens is not None ++ attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = _make_sliding_window_bias( @@ -9615,15 +9881,16 @@ index d3c61ea26..3ec6ee9ee 100644 - gen_=None, - logits_soft_cap=self.logits_soft_cap, - ) -+ att_masks = [None] * len(prefill_meta.seq_lens) -+ prefill_meta.attn_bias = att_masks -+ ++ seq_lens, _ = attn_metadata.get_seq_lens(attn_type) ++ attn_masks = [None] * len(seq_lens) ++ attn_metadata.set_attn_bias(attn_masks, attn_type) ++ + if self.ipex_varlen_attn: + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) -+ ++ + tmp = [0] + tmp.extend(prefill_meta.seq_lens) + seqlen = torch.tensor(tmp) @@ -9643,58 +9910,61 @@ index d3c61ea26..3ec6ee9ee 100644 + return_softmax=False, + gen_=None, + logits_soft_cap=self.logits_soft_cap) -+ else: -+ output = torch.empty( -+ (num_tokens, self.num_heads, self.head_size), -+ dtype=query.dtype, device=query.device) ++ else: ++ # output = torch.empty( ++ # (num_tokens, self.num_heads, self.head_size), ++ # dtype=query.dtype, device=query.device) + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + import math + scale = 1 / math.sqrt(self.head_size) if self.scale is None else self.scale -+ start = 0 -+ for seq_len, mask in zip(prefill_meta.seq_lens, -+ prefill_meta.attn_bias): -+ end = start + seq_len -+ if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap, self.attn_type): ++ causal_attn = (attn_type == AttentionType.DECODER) ++ seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) ++ start_q, start_kv = 0, 0 ++ ++ for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, ++ attn_masks): ++ end_q = start_q + seq_len_q ++ end_kv = start_kv + seq_len_kv ++ if self.alibi_slopes is None and use_sdp_causal(self.head_size, query, self.logits_soft_cap, attn_type): + 
import xe_addons + if mask is not None: + mask = mask.unsqueeze(0) + if self.logits_soft_cap == 0 or self.head_size != 256: + sub_out = xe_addons.sdp_causal( -+ query[None, :, start:end, :].contiguous(), -+ key[None, :, start:end, :].contiguous(), -+ value[None, :, start:end, :].contiguous(), ++ query[None, :, start_q:end_q, :].contiguous(), ++ key[None, :, start_kv:end_kv, :].contiguous(), ++ value[None, :, start_kv:end_kv, :].contiguous(), + mask, + scale).squeeze(0).movedim( + query.dim() - 2, 0) + else: + sub_out = xe_addons.gemma2_sdp_causal( -+ query[None, :, start:end, :].contiguous(), -+ key[None, :, start:end, :].contiguous(), -+ value[None, :, start:end, :].contiguous(), ++ query[None, :, start_q:end_q, :].contiguous(), ++ key[None, :, start_kv:end_kv, :].contiguous(), ++ value[None, :, start_kv:end_kv, :].contiguous(), + mask, + self.logits_soft_cap, + self.scale).squeeze(0).movedim( -+ query.dim() - 2, 0) ++ query.dim() - 2, 0) + else: + sub_out = torch.nn.functional.scaled_dot_product_attention( -+ query[None, :, start:end, :], -+ key[None, :, start:end, :], -+ value[None, :, start:end, :], ++ query[None, :, start_q:end_q, :], ++ key[None, :, start_kv:end_kv, :], ++ value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, -+ is_causal=is_causal, ++ is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim( + query.dim() - 2, 0) -+ output[start:end, :, :] = sub_out -+ start = end ++ output[start_q:end_q, :, :] = sub_out ++ start_q, start_kv = end_q, end_kv ++ else: # prefix-enabled attention - raise RuntimeError( - "IPEX backend doesn't support prefix decoding.") -- -- else: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, @@ -9711,9 +9981,10 @@ index d3c61ea26..3ec6ee9ee 100644 + out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item(), torch.amax(query_len).item()) + else: + out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item()) -+ assert output[:num_prefill_tokens].shape == out.shape -+ output[:num_prefill_tokens] = out -+ ++ assert output[:num_prefill_query_tokens].shape == out.shape ++ output[:num_prefill_query_tokens] = out + +- else: + if decode_meta := attn_metadata.decode_metadata: # Decoding run. - max_seq_len = attn_metadata.max_decode_seq_len @@ -9725,8 +9996,15 @@ index d3c61ea26..3ec6ee9ee 100644 + num_seqs, num_heads, head_size = decode_query.shape max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) ++ ( ++ seq_lens_arg, ++ max_seq_len_arg, ++ block_tables_arg, ++ ) = decode_meta.get_seq_len_block_table_args(attn_type) # NOTE(woosuk): We use a simple heuristic to decide whether to use -@@ -281,59 +542,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of +@@ -281,59 +792,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): # TODO(woosuk): Tune this heuristic. 
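# [Editor's note: illustrative sketch, not part of the patch] The partition count
# referenced by this heuristic is the ceiling division computed a few lines
# above. A worked example, assuming _PARTITION_SIZE = 512:
_PARTITION_SIZE = 512
max_seq_len = 1300
max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# -> 3 partitions here; any max_seq_len <= 512 yields exactly one partition,
#    letting the V1 kernel skip the cross-partition reduction.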
# For context len > 8192, use V2 kernel to avoid shared memory # shortage. @@ -9810,10 +10088,10 @@ index d3c61ea26..3ec6ee9ee 100644 + value_cache, + self.num_kv_heads, + self.scale, -+ decode_meta.block_tables, -+ decode_meta.seq_lens_tensor, ++ block_tables_arg, ++ seq_lens_arg, + block_size, -+ max_seq_len, ++ max_seq_len_arg, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, @@ -9844,10 +10122,10 @@ index d3c61ea26..3ec6ee9ee 100644 + value_cache, + self.num_kv_heads, + self.scale, -+ decode_meta.block_tables, -+ decode_meta.seq_lens_tensor, ++ block_tables_arg, ++ seq_lens_arg, + block_size, -+ max_seq_len, ++ max_seq_len_arg, + self.alibi_slopes, + self.kv_cache_dtype, + layer._k_scale, @@ -9858,6 +10136,121 @@ index d3c61ea26..3ec6ee9ee 100644 # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) +@@ -386,3 +924,114 @@ def _make_sliding_window_bias( + attn_biases.append(mask.to(dtype)) + + return attn_biases ++ ++def get_num_prefill_decode_query_kv_tokens( ++ attn_metadata, ++ attn_type: str, ++) -> Tuple[int, int, int]: ++ """ ++ Calculate the number of prefill and decode tokens for query, key/value ++ based on the attention metadata and the specified attention type. ++ ++ Args: ++ attn_metadata (FlashAttentionMetadata): Attention Metadata object. ++ attn_type (AttentionType): The type of attention being used. ++ Returns: ++ Tuple[int, int, int]: A tuple containing three integers: ++ - The number of prefill query tokens. ++ - The number of prefill key/value tokens. ++ - The number of decode query tokens. ++ ++ Raises: ++ AssertionError: If the number of encoder tokens in `attn_metadata` ++ is `None` when required for the calculations. ++ """ ++ num_prefill_query_tokens = 0 ++ num_decode_query_tokens = 0 ++ num_prefill_kv_tokens = 0 ++ if attn_type == AttentionType.ENCODER: ++ # Encoder attention is only invoked during prefill phase. ++ # The same input servers a both query and key. ++ assert attn_metadata.num_encoder_tokens is not None ++ num_prefill_query_tokens = attn_metadata.num_encoder_tokens ++ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens ++ num_decode_query_tokens = 0 ++ elif attn_type == AttentionType.ENCODER_DECODER: ++ assert attn_metadata.num_encoder_tokens is not None ++ num_prefill_query_tokens = attn_metadata.num_prefill_tokens ++ # The key is the encoder/cross-attention. ++ num_prefill_kv_tokens = attn_metadata.num_encoder_tokens ++ num_decode_query_tokens = attn_metadata.num_decode_tokens ++ else: # attn_type == AttentionType.DECODER or ++ # attn_type == AttentionType.ENCODER_ONLY ++ num_prefill_query_tokens = attn_metadata.num_prefill_tokens ++ num_prefill_kv_tokens = attn_metadata.num_prefill_tokens ++ num_decode_query_tokens = attn_metadata.num_decode_tokens ++ ++ return (num_prefill_query_tokens, num_prefill_kv_tokens, ++ num_decode_query_tokens) ++ ++def _get_query_key_seq_metadata( ++ attn_metadata, ++ is_prompt: bool, ++ attn_type: str, ++) -> tuple: ++ """ ++ Returns sequence metadata for key and query based on the specified ++ attention type and whether input is a prompt. ++ ++ This function computes the starting locations and maximum sequence lengths ++ for key and query sequences for different attention types. ++ ++ Args: ++ attn_metadata: The attention metadata object ++ is_prompt (bool): A flag indicating if the input is a prompt ++ attn_type (AttentionType): The type of attention being used. ++ ++ Returns: ++ tuple: A tuple containing four integers: ++ - Starting location for the query sequence. 
++ - Maximum sequence length for the query sequence. ++ - Starting location for the key sequence. ++ - Maximum sequence length for the key sequence. ++ ++ Raises: ++ AttributeError: If an invalid attention type is provided. ++ """ ++ if attn_type == AttentionType.DECODER: ++ # Decoder self-attention ++ # Choose max_seq_len based on whether we are in prompt_run ++ if is_prompt: ++ max_seq_len = attn_metadata.max_prefill_seq_len ++ else: ++ max_seq_len = attn_metadata.max_decode_seq_len ++ return (attn_metadata.seq_start_loc, max_seq_len, ++ attn_metadata.seq_start_loc, max_seq_len) ++ ++ elif attn_type == AttentionType.ENCODER_DECODER: ++ # This is cross attention between the where the key ++ # is the precomputed encoder attention and query ++ # is the input sequence. ++ # Choose query max length based on whether it is prompt ++ # or not. ++ if is_prompt: ++ max_seq_len = attn_metadata.max_prefill_seq_len ++ else: ++ max_seq_len = attn_metadata.max_decode_seq_len ++ return (attn_metadata.seq_start_loc, max_seq_len, ++ attn_metadata.encoder_seq_start_loc, ++ attn_metadata.max_encoder_seq_len) ++ elif attn_type == AttentionType.ENCODER: ++ # For encoder attention both the query and the key are same i.e the ++ # encoder sequence. ++ return (attn_metadata.encoder_seq_start_loc, ++ attn_metadata.max_encoder_seq_len, ++ attn_metadata.encoder_seq_start_loc, ++ attn_metadata.max_encoder_seq_len) ++ elif attn_type == AttentionType.ENCODER_ONLY: ++ assert is_prompt, "Should not have decode for encoder only model." ++ return (attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len, ++ attn_metadata.seq_start_loc, attn_metadata.max_prefill_seq_len) ++ else: ++ raise AttributeError(f"Invalid attention type {str(attn_type)}") ++ diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index 6ab69ea5b..9604f35f6 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py @@ -10178,7 +10571,7 @@ index 669fb96e6..d28984b4b 100644 multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py -index 0579893e5..dfb422ff5 100644 +index 0579893e5..6bd8ad22d 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -330,7 +330,7 @@ class InputRegistry: @@ -10186,7 +10579,7 @@ index 0579893e5..dfb422ff5 100644 from vllm.sequence import SequenceData - if mm_registry.has_processor(model_config): -+ if False and mm_registry.has_processor(model_config): ++ if "whisper" in model_config.model.lower() and mm_registry.has_processor(model_config): processor = mm_registry.create_processor(model_config, disable_cache=True) profiler = MultiModalProfiler(processor) @@ -11914,10 +12307,10 @@ index 1b1738f88..2c2ed67b9 100644 layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py new file mode 100644 -index 000000000..07453f636 +index 000000000..8e6e2c11a --- /dev/null +++ b/vllm/model_executor/models/glm4.py -@@ -0,0 +1,312 @@ +@@ -0,0 +1,318 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The Zhipu AI team. 
@@ -11968,6 +12361,9 @@ index 000000000..07453f636 +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + + ++from vllm.logger import init_logger ++logger = init_logger(__name__) ++ +class Glm4Attention(nn.Module): + + def __init__(self, @@ -12002,7 +12398,7 @@ index 000000000..07453f636 + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads -+ self.rotary_dim = int(partial_rotary_factor * self.head_dim) ++ self.rotary_dim = self.head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 @@ -12031,6 +12427,7 @@ index 000000000..07453f636 + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + is_neox_style=False, ++ dtype=torch.float32, + ) + self.attn = Attention(self.num_heads, + self.head_dim, @@ -12122,6 +12519,8 @@ index 000000000..07453f636 + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) ++ hidden_states = hidden_states.to(torch.float32) ++ hidden_states = torch.clamp(hidden_states, min=-1e36, max=1e36) + hidden_states = self.post_mlp_layernorm(hidden_states) + + return hidden_states, residual @@ -12232,7 +12631,7 @@ index 000000000..07453f636 + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py -index c190a4585..195f48f64 100644 +index c190a4585..dda2a96cc 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -111,8 +111,11 @@ class EVA2CLIPAttention(nn.Module): @@ -12243,12 +12642,22 @@ index c190a4585..195f48f64 100644 - self.scale) + # self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, + # self.scale) -+ from siglip import SelfAttention ++ from vllm.model_executor.models.siglip import SelfAttention + self.attn = SelfAttention(self.num_heads_per_rank, self.head_dim, + self.scale) self.output_dropout = torch.nn.Dropout(config.dropout_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: +@@ -332,7 +335,9 @@ class EVA2CLIPModel(nn.Module): + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) ++ shape = x.shape + x = self.linear_proj(x) ++ x = x.reshape(shape) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5fab9df3f..f8e6fbe24 100644 --- a/vllm/model_executor/models/minicpmv.py @@ -14300,6 +14709,23 @@ index c0a3c59ba..8614c2273 100644 "XverseForCausalLM": ("llama", "LlamaForCausalLM"), "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"), # [Encoder-decoder] +diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py +index a09741a55..d989a12fa 100644 +--- a/vllm/model_executor/models/roberta.py ++++ b/vllm/model_executor/models/roberta.py +@@ -155,8 +155,12 @@ class RobertaClassificationHead(nn.Module): + def forward(self, features, **kwargs): + x = features[0, :] # take token (equiv. 
to [CLS]) + x = self.dense(x) ++ if isinstance(x, tuple): ++ x = x[0] + x = torch.tanh(x) + x = self.out_proj(x) ++ if isinstance(x, tuple): ++ x = x[0] + return x + + diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index cecad9e89..7eaabd1db 100644 --- a/vllm/model_executor/models/siglip.py @@ -16105,8 +16531,540 @@ index 86e6d9752..ad80bf54e 100644 @dataclass(frozen=True) +diff --git a/vllm/worker/xpu_enc_dec_model_runner.py b/vllm/worker/xpu_enc_dec_model_runner.py +new file mode 100644 +index 000000000..dffc7b367 +--- /dev/null ++++ b/vllm/worker/xpu_enc_dec_model_runner.py +@@ -0,0 +1,526 @@ ++# SPDX-License-Identifier: Apache-2.0 ++ ++import dataclasses ++import itertools ++from typing import Any, Dict, List, Optional, Tuple, Type, cast ++ ++import torch ++import torch.distributed ++ ++from vllm.attention.backends.abstract import (AttentionBackend, ++ AttentionMetadata) ++from vllm.attention.backends.utils import PAD_SLOT_ID ++from vllm.attention.selector import (get_env_variable_attn_backend, ++ get_global_forced_attn_backend) ++from vllm.config import VllmConfig ++from vllm.forward_context import set_forward_context ++from vllm.inputs import INPUT_REGISTRY, InputRegistry ++from vllm.logger import init_logger ++from vllm.model_executor import SamplingMetadata ++from vllm.model_executor.layers.sampler import SamplerOutput ++from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, ++ MultiModalRegistry) ++from vllm.platforms import _Backend ++from vllm.sampling_params import SamplingParams ++from vllm.sequence import (IntermediateTensors, PoolerOutput, ++ SequenceGroupMetadata) ++from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad ++from vllm.worker.model_runner import (GPUModelRunnerBase, ++ ModelInputForGPUBuilder) ++from vllm.worker.xpu_model_runner import (XPUModelRunnerBase, ++ ModelInputForXPUBuilder, ++ ModelInputForXPUWithSamplingMetadata) ++from vllm.worker.model_runner_base import ( ++ _add_attn_metadata_broadcastable_dict, ++ _add_sampling_metadata_broadcastable_dict) ++from vllm.worker.utils import assert_enc_dec_mr_supported_scenario ++ ++logger = init_logger(__name__) ++ ++ ++@dataclasses.dataclass(frozen=True) ++class XPUEncoderDecoderModelInput(ModelInputForXPUWithSamplingMetadata): ++ """ ++ Used by the EncoderDecoderModelRunner. 
++ """ ++ encoder_input_tokens: Optional[torch.Tensor] = None ++ encoder_input_positions: Optional[torch.Tensor] = None ++ ++ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: ++ tensor_dict = { ++ "input_tokens": self.input_tokens, ++ "input_positions": self.input_positions, ++ "encoder_input_tokens": self.encoder_input_tokens, ++ "encoder_input_positions": self.encoder_input_positions, ++ "virtual_engine": self.virtual_engine, ++ "request_ids_to_seq_ids": self.request_ids_to_seq_ids, ++ "finished_requests_ids": self.finished_requests_ids, ++ "multi_modal_kwargs": self.multi_modal_kwargs, ++ } ++ _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) ++ _add_sampling_metadata_broadcastable_dict(tensor_dict, ++ self.sampling_metadata) ++ return tensor_dict ++ ++ @classmethod ++ def from_broadcasted_tensor_dict( ++ cls, ++ tensor_dict: Dict[str, Any], ++ attn_backend: Optional["AttentionBackend"] = None, ++ ) -> "XPUEncoderDecoderModelInput": ++ return cast( ++ XPUEncoderDecoderModelInput, ++ super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) ++ ++ ++class XPUEncoderDecoderModelRunner(XPUModelRunnerBase[XPUEncoderDecoderModelInput]): ++ _model_input_cls: Type[XPUEncoderDecoderModelInput] = ( ++ XPUEncoderDecoderModelInput) ++ _builder_cls: Type[ModelInputForXPUBuilder] = (ModelInputForXPUBuilder) ++ ++ def __init__( ++ self, ++ vllm_config: VllmConfig, ++ kv_cache_dtype: Optional[str] = "auto", ++ is_driver_worker: bool = False, ++ input_registry: InputRegistry = INPUT_REGISTRY, ++ mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ++ ): ++ ''' ++ EncoderDecoderModelRunner constructor. ++ ++ `lora_config` and `prompt_adapter_config` are ++ unused (since these features are not yet supported for encoder/decoder ++ models) but these arguments are present here for compatibility with ++ the base-class constructor. ++ ''' ++ # self._maybe_force_supported_attention_backend() ++ ++ super().__init__( ++ vllm_config=vllm_config, ++ kv_cache_dtype=kv_cache_dtype, ++ is_driver_worker=is_driver_worker, ++ ) ++ ++ # Crash for unsupported encoder/scenarios ++ assert_enc_dec_mr_supported_scenario(self) ++ ++ def _maybe_force_supported_attention_backend(self): ++ ''' ++ Force vLLM to use the XFormers attention backend, ++ which is currently the only supported option. ++ ''' ++ ++ def raise_backend_err(): ++ # The user has specified an attention backend override ++ # which is invalid for encoder/decoder models ++ raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND) ++ ++ maybe_env_var_forced_backend = get_env_variable_attn_backend() ++ maybe_global_forced_backend = get_global_forced_attn_backend() ++ is_forced_by_global = maybe_global_forced_backend is not None ++ is_forced_by_env_var = maybe_env_var_forced_backend is not None ++ if is_forced_by_global: # noqa: SIM102 ++ # Backend override enforced by global variable takes ++ # precedence over vLLM backend environment variable. 
++ if maybe_global_forced_backend not in\ ++ [_Backend.XFORMERS, _Backend.FLASH_ATTN]: ++ raise_backend_err() ++ elif is_forced_by_env_var: # noqa: SIM102 ++ # Backend override enforced by vLLM backend ++ # environment variable ++ if maybe_env_var_forced_backend not in\ ++ [_Backend.XFORMERS, _Backend.FLASH_ATTN]: ++ raise_backend_err() ++ ++ def _list_to_int32_tensor( ++ self, ++ _list: List[int], ++ ) -> torch.Tensor: ++ return torch.tensor(_list, dtype=torch.int32, device=self.device) ++ ++ def _list_to_long_tensor( ++ self, ++ _list: List[int], ++ ) -> torch.Tensor: ++ return torch.tensor(_list, dtype=torch.long, device=self.device) ++ ++ def _empty_int32_tensor(self) -> torch.Tensor: ++ return self._list_to_int32_tensor([]) ++ ++ def _empty_long_tensor(self) -> torch.Tensor: ++ return self._list_to_long_tensor([]) ++ ++ @torch.inference_mode() ++ def execute_model( ++ self, ++ model_input: XPUEncoderDecoderModelInput, ++ kv_caches: List[torch.Tensor], ++ intermediate_tensors: Optional[IntermediateTensors] = None, ++ num_steps: int = 1, ++ ) -> Optional[List[PoolerOutput]]: ++ if num_steps > 1: ++ raise ValueError("num_steps > 1 is not supported in " ++ "EncoderDecoderModelRunner") ++ ++ model_executable = self.model ++ ++ # seqlen_agnostic_kwargs = { ++ # "finished_requests_ids": model_input.finished_requests_ids, ++ # "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, ++ # } if self.has_inner_state else {} ++ ++ multi_modal_kwargs = model_input.multi_modal_kwargs or {} ++ with set_forward_context(model_input.attn_metadata, self.vllm_config, ++ model_input.virtual_engine): ++ hidden_or_intermediate_states = model_executable( ++ input_ids=model_input.input_tokens, ++ positions=model_input.input_positions, ++ encoder_input_ids=model_input.encoder_input_tokens, ++ encoder_positions=model_input.encoder_input_positions, ++ intermediate_tensors=intermediate_tensors, ++ **MultiModalKwargs.as_kwargs(multi_modal_kwargs, ++ device=self.device), ++ ) ++ # **seqlen_agnostic_kwargs) ++ ++ logits = self.model.compute_logits(hidden_or_intermediate_states, ++ model_input.sampling_metadata) ++ ++ if not self.is_driver_worker: ++ return [] ++ ++ if model_input.async_callback is not None: ++ model_input.async_callback() ++ ++ # Sample the next token. ++ output: SamplerOutput = self.model.sample( ++ logits=logits, ++ sampling_metadata=model_input.sampling_metadata, ++ ) ++ ++ return [output] ++ ++ def make_model_input_from_broadcasted_tensor_dict( ++ self, tensor_dict: Dict[str, Any]) -> XPUEncoderDecoderModelInput: ++ return XPUEncoderDecoderModelInput.from_broadcasted_tensor_dict( ++ tensor_dict, ++ attn_backend=self.attn_backend, ++ ) ++ ++ def prepare_model_input( ++ self, ++ seq_group_metadata_list: List[SequenceGroupMetadata], ++ virtual_engine: int = 0, ++ finished_requests_ids: Optional[List[str]] = None ++ ) -> XPUEncoderDecoderModelInput: ++ """Prepare the model input based on a given sequence group, including ++ metadata for the sampling step. ++ ++ Since chunked prefill is not supported for encoder/decoder models, ++ `input_tokens` is assumed to be either entirely prefill tokens or ++ entirely decode tokens. 
++ ++ """ ++ model_input = self._prepare_model_input_tensors( ++ seq_group_metadata_list, finished_requests_ids) ++ ( ++ attn_metadata, ++ encoder_input_tokens_tensor, ++ encoder_input_positions_tensor, ++ ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, ++ model_input)) ++ # Inject attn_metadata encoder/cross-attention fields & ++ # encoder input tokens/positions into model_input. ++ # Frozen dataclass fields cannot be modified, so use ++ # dataclasses.replace to construct a new model input ++ # instance. ++ model_input = dataclasses.replace( ++ model_input, ++ attn_metadata=attn_metadata, ++ encoder_input_tokens=encoder_input_tokens_tensor, ++ encoder_input_positions=encoder_input_positions_tensor, ++ ) ++ ++ generators = self.get_generators(finished_requests_ids) ++ sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, ++ model_input.seq_lens, ++ model_input.query_lens, ++ self.device, ++ pin_memory=False, ++ generators=generators, ++ cache=self.sampling_metadata_cache) ++ is_prompt = (seq_group_metadata_list[0].is_prompt ++ if seq_group_metadata_list else None) ++ return dataclasses.replace(model_input, ++ sampling_metadata=sampling_metadata, ++ is_prompt=is_prompt, ++ virtual_engine=virtual_engine) ++ ++ @torch.inference_mode() ++ def profile_run(self) -> None: ++ # Enable top-k sampling to reflect the accurate memory usage. ++ sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) ++ max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens ++ max_num_seqs = self.scheduler_config.max_num_seqs ++ ++ # Profile memory usage with max_num_sequences sequences and the total ++ # number of tokens equal to max_num_batched_tokens. ++ seqs: List[SequenceGroupMetadata] = [] ++ ++ max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( ++ self.model_config) ++ if max_mm_tokens > 0: ++ logger.info("Starting profile run for multi-modal models.") ++ ++ batch_size = 0 ++ import os ++ self_max_num_batched_tokens = os.getenv("IPEX_LLM_SELF_MAX_NUM_BATCHED_TOKENS", None) ++ if self_max_num_batched_tokens is not None: ++ max_num_batched_tokens = int(self_max_num_batched_tokens) ++ self_max_num_seqs = os.getenv("IPEX_LLM_SELF_MAX_NUM_SEQS", None) ++ if self_max_num_seqs is not None: ++ max_num_seqs = int(self_max_num_seqs) ++ else: ++ max_num_seqs = 1 ++ for group_id in range(max_num_seqs): ++ seq_len = (max_num_batched_tokens // max_num_seqs + ++ (group_id < max_num_batched_tokens % max_num_seqs)) ++ batch_size += seq_len ++ ++ decoder_dummy_data = self.input_registry \ ++ .dummy_data_for_profiling(self.model_config, ++ seq_len, ++ self.mm_registry, ++ is_encoder_data=False) ++ encoder_dummy_data = self.input_registry \ ++ .dummy_data_for_profiling(self.model_config, ++ seq_len, ++ self.mm_registry, ++ is_encoder_data=True) ++ ++ # Having more tokens is over-conservative but otherwise fine ++ assert len( ++ decoder_dummy_data.seq_data.prompt_token_ids ++ ) >= seq_len, ( ++ f"Expected at least {seq_len} dummy tokens for profiling, " ++ f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}" ++ ) ++ ++ assert decoder_dummy_data.multi_modal_data is None or \ ++ encoder_dummy_data.multi_modal_data is None, ( ++ "Multi-modal data can't be provided in both encoder and decoder" ++ ) ++ ++ seq = SequenceGroupMetadata( ++ request_id=str(group_id), ++ is_prompt=True, ++ seq_data={group_id: decoder_dummy_data.seq_data}, ++ sampling_params=sampling_params, ++ block_tables=None, ++ encoder_seq_data=encoder_dummy_data.seq_data, ++ 
cross_block_table=None, ++ multi_modal_data=decoder_dummy_data.multi_modal_data ++ or encoder_dummy_data.multi_modal_data, ++ multi_modal_placeholders=decoder_dummy_data. ++ multi_modal_placeholders ++ or encoder_dummy_data.multi_modal_placeholders) ++ seqs.append(seq) ++ ++ finished_requests_ids = [seq.request_id for seq in seqs] ++ model_input = self.prepare_model_input( ++ seqs, finished_requests_ids=finished_requests_ids) ++ intermediate_tensors = None ++ ++ num_layers = self.model_config.get_num_layers(self.parallel_config) ++ kv_caches = [None] * num_layers ++ ++ self.execute_model(model_input, kv_caches, intermediate_tensors) ++ torch.xpu.synchronize() ++ return ++ ++ def _prepare_encoder_model_input_tensors( ++ self, ++ seq_group_metadata_list: List[SequenceGroupMetadata], ++ model_input: XPUEncoderDecoderModelInput, ++ ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], ++ Optional[torch.Tensor]]: ++ """Helper method to prepare the encoder- and cross-attn-related ++ model inputs based on a given sequence group. These additional inputs ++ are used to augment an already-computed `XPUEncoderDecoderModelInput` ++ data structure which already has decoder-related model inputs ++ populated. ++ ++ Sets the following attn_metadata fields: ++ * `num_encoder_tokens` ++ * `encoder_seq_lens` ++ * `encoder_seq_lens_tensor` ++ * `max_encoder_seq_len` ++ * `cross_slot_mapping` ++ * `cross_block_tables` ++ ++ Constructs a new model inputs data structure, based on ++ (1) the existing fields in the `model_inputs` argument, ++ and (2) the following additional fields which are ++ computed (or in the case of `attn_metadata`, updated) ++ by this function: ++ * attn_metadata ++ * encoder_input_tokens ++ * encoder_input_positions ++ ++ Arguments: ++ ++ * seq_group_metadata_list: list of sequence groups for which to ++ compute inputs ++ * model_inputs: model inputs data structure with decoder-oriented ++ fields already computed. ++ ++ Return: ++ ++ * Updated model inputs data structure ++ """ ++ ++ if len(seq_group_metadata_list) == 0: ++ return (model_input.attn_metadata, None, None) ++ ++ # Since we are not supporting chunked prefill either the entire ++ # batch is prefill or it is decode ++ is_prompt = seq_group_metadata_list[0].is_prompt ++ ++ # Build encoder inputs ++ encoder_seq_lens: List[int] = [] ++ if is_prompt: ++ # Prefill phase. ++ cross_block_tables = self._empty_int32_tensor().view( ++ len(seq_group_metadata_list), -1) ++ ++ # Extract input tokens/positions, cross-attention slot-mapping, ++ # & seq len from each sequence group metadata ++ ( ++ encoder_input_tokens, ++ encoder_input_positions, ++ cross_slot_mapping, ++ ) = ( ++ [], ++ [], ++ [], ++ ) ++ for seq_group_metadata in seq_group_metadata_list: ++ # Build seq lens ++ seq_len = seq_group_metadata.encoder_seq_data.get_len() ++ token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() ++ encoder_seq_lens.append(seq_len) ++ ++ # Build slot mapping ++ is_profile_run = (seq_group_metadata.block_tables is None) ++ if is_profile_run: ++ # During memory profiling, the block tables are not ++ # initialized yet. In this case, we just use a dummy ++ # slot mapping. ++ # In embeddings, the block tables are {seq_id: None}. 
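# [Editor's note: illustrative sketch, not part of the patch] Outside profiling
# (the else-branch just below), each token's physical KV slot is computed from
# the cross block table as block_number * block_size + block_offset. A
# standalone example with assumed values:
block_size = 16
cross_block_table = [7, 3]  # physical blocks assigned to one encoder sequence
seq_len = 20
cross_slot_mapping = []
for i in range(seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)
# tokens 0-15 land in slots 112..127 (block 7), tokens 16-19 in slots 48..51 (block 3)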
++ cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) ++ else: ++ for i in range(0, seq_len): ++ block_number = seq_group_metadata.cross_block_table[ ++ i // self.block_size] ++ block_offset = i % self.block_size ++ slot = block_number * self.block_size + block_offset ++ cross_slot_mapping.append(slot) ++ ++ # Build encoder input tokens ++ encoder_input_tokens.extend(token_ids) ++ encoder_input_positions.extend(list(range(0, seq_len))) ++ ++ # Convert tokens/positions & cross-attention ++ # slot-mapping to encoder input tensors ++ encoder_input_tokens_tensor = self._list_to_long_tensor( ++ encoder_input_tokens) ++ encoder_input_positions_tensor = self._list_to_long_tensor( ++ encoder_input_positions) ++ cross_slot_mapping_tensor = self._list_to_long_tensor( ++ cross_slot_mapping) ++ ++ else: ++ # Decode phase. ++ encoder_input_tokens_tensor = self._empty_long_tensor() ++ encoder_input_positions_tensor = self._empty_long_tensor() ++ cross_slot_mapping_tensor = self._empty_long_tensor() ++ # Extract cross-attention block tables & ++ # seq len from each sequence group metadata. ++ # Cross-attention block tables are empty ++ # during vLLM memory profiling. ++ cross_block_tables = [] ++ for seq_group_metadata in seq_group_metadata_list: ++ for _ in range(len(seq_group_metadata.seq_data)): ++ encoder_seq_lens.append( ++ seq_group_metadata.encoder_seq_data.get_len()) ++ cross_block_table = seq_group_metadata.cross_block_table ++ cross_block_tables.append([] if ( ++ cross_block_table is None) else cross_block_table) ++ ++ # if (model_input.attn_metadata is not None ++ # and model_input.attn_metadata.use_cuda_graph and False): ++ if False: ++ # We will be using CUDA graph replay for this decode. ++ max_len_of_block_table = self.get_max_block_per_batch() ++ batch_size = len(encoder_seq_lens) ++ graph_batch_size = self.vllm_config.pad_for_cudagraph( ++ batch_size) ++ assert graph_batch_size >= batch_size ++ cuda_graph_pad_size = graph_batch_size - batch_size ++ # extend the cross_block_tables and encoder_seq_lens to match ++ # the graph_batch_size. 
++ cross_block_tables.extend([[] ++ for _ in range(cuda_graph_pad_size) ++ ]) ++ encoder_seq_lens.extend( ++ itertools.repeat(1, cuda_graph_pad_size)) ++ ++ else: ++ max_len_of_block_table = max( ++ len(block_table) for block_table in cross_block_tables) ++ ++ cross_block_tables = make_tensor_with_pad( ++ cross_block_tables, ++ max_len=max_len_of_block_table, ++ pad=0, ++ dtype=torch.int32, ++ device=self.device, ++ ) ++ ++ # Compute encoder sequence lengths & encoder ++ # sequence starting offset tensors ++ max_encoder_seq_len = max(encoder_seq_lens, default=0) ++ encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) ++ encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + ++ 1, ++ dtype=torch.int32, ++ device=self.device) ++ torch.cumsum(encoder_seq_lens_tensor, ++ dim=0, ++ dtype=encoder_seq_start_loc.dtype, ++ out=encoder_seq_start_loc[1:]) ++ ++ # Update attention metadata with encoder-oriented attributes ++ attn_metadata = model_input.attn_metadata ++ assert attn_metadata is not None ++ ( ++ attn_metadata.num_encoder_tokens, ++ attn_metadata.encoder_seq_lens, ++ attn_metadata.encoder_seq_lens_tensor, ++ attn_metadata.max_encoder_seq_len, ++ attn_metadata.encoder_seq_start_loc, ++ attn_metadata.cross_slot_mapping, ++ attn_metadata.cross_block_tables, ++ ) = ( ++ sum(encoder_seq_lens), ++ encoder_seq_lens, ++ encoder_seq_lens_tensor, ++ max_encoder_seq_len, ++ encoder_seq_start_loc, ++ cross_slot_mapping_tensor, ++ cross_block_tables, ++ ) ++ ++ return (attn_metadata, encoder_input_tokens_tensor, ++ encoder_input_positions_tensor) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py -index 9d49b4385..7396b0c89 100644 +index 9d49b4385..78e0c54f2 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -5,8 +5,8 @@ import time @@ -16163,7 +17121,7 @@ index 9d49b4385..7396b0c89 100644 TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") -@@ -51,6 +63,8 @@ class ModelInputForXPU(ModelRunnerInputBase): +@@ -51,8 +63,12 @@ class ModelInputForXPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None @@ -16171,8 +17129,12 @@ index 9d49b4385..7396b0c89 100644 + lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None ++ request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None ++ finished_requests_ids: Optional[List[str]] = None virtual_engine: Optional[int] = None -@@ -62,6 +76,9 @@ class ModelInputForXPU(ModelRunnerInputBase): + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None +@@ -62,6 +78,9 @@ class ModelInputForXPU(ModelRunnerInputBase): tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, @@ -16182,7 +17144,13 @@ index 9d49b4385..7396b0c89 100644 } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) -@@ -90,6 +107,9 @@ class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): +@@ -85,11 +104,15 @@ class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): + Used by the ModelRunner. 
+ """ + sampling_metadata: Optional["SamplingMetadata"] = None ++ is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, @@ -16192,7 +17160,7 @@ index 9d49b4385..7396b0c89 100644 } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, -@@ -112,7 +132,7 @@ class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): +@@ -112,7 +135,7 @@ class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): def __init__(self, @@ -16201,7 +17169,7 @@ index 9d49b4385..7396b0c89 100644 finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() self.runner = runner -@@ -121,6 +141,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -121,6 +144,10 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.device = self.runner.device @@ -16212,7 +17180,7 @@ index 9d49b4385..7396b0c89 100644 def prepare(self, finished_requests_ids: Optional[List[str]] = None) -> None: -@@ -130,33 +154,275 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -130,33 +157,283 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): self.seq_group_metadata_list.append(seq_group_metadata) def build(self) -> ModelInputForXPU: @@ -16455,8 +17423,14 @@ index 9d49b4385..7396b0c89 100644 + else: + context_lens_tensor = context_lens + query_start_loc = query_lens - ++ + # Generate attn_metadata ++ from itertools import accumulate ++ seq_start_loc = list(accumulate(seq_lens, initial=0)) ++ seq_start_loc = torch.tensor(seq_start_loc, ++ dtype=torch.int, ++ device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + # FIXME: Later maybe we can get rid of this parameter + is_prompt=is_prompt, #1 @@ -16472,9 +17446,11 @@ index 9d49b4385..7396b0c89 100644 + max_seqlen=max(query_lens), + seq_lens_tensor=seq_lens_tensor, # 9 + # max_query_len=max_query_len, ++ #max_prefill_seq_len=0 if is_prompt else max(seq_lens), ++ max_prefill_seq_len=max(prefill_seq_lens) if is_prompt else 0, + max_decode_seq_len=max_decode_seq_len, # 10 + query_start_loc=query_start_loc, -+ # seq_start_loc=seq_start_loc, ++ seq_start_loc=seq_start_loc, + context_lens=context_lens_tensor, + block_tables=block_tables if need_block_table else torch.tensor([], device=self.device, dtype=torch.int) # 11 + ) @@ -16503,7 +17479,7 @@ index 9d49b4385..7396b0c89 100644 assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] -@@ -166,6 +432,9 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -166,6 +443,9 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): multi_modal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -16513,7 +17489,7 @@ index 9d49b4385..7396b0c89 100644 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt -@@ -184,29 +453,55 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -184,29 +464,55 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): # Token position ids # NOTE(woosuk): Here we assume that 
the first token in the prompt # is always the first token in the sequence. @@ -16591,7 +17567,7 @@ index 9d49b4385..7396b0c89 100644 if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized -@@ -276,26 +571,34 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -276,26 +582,34 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): ) multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) @@ -16629,7 +17605,7 @@ index 9d49b4385..7396b0c89 100644 for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] -@@ -315,6 +618,8 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -315,6 +629,8 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) @@ -16638,7 +17614,7 @@ index 9d49b4385..7396b0c89 100644 if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // -@@ -359,17 +664,14 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): +@@ -359,17 +675,14 @@ class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): num_prefills=0, block_tables=block_tables, ) @@ -16662,7 +17638,7 @@ index 9d49b4385..7396b0c89 100644 def __init__( self, -@@ -410,6 +712,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -410,6 +723,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): # Lazy initialization. self.model: nn.Module # Set after init_Model @@ -16675,7 +17651,7 @@ index 9d49b4385..7396b0c89 100644 self.sampling_metadata_cache: SamplingMetadataCache = \ SamplingMetadataCache() \ -@@ -432,6 +740,15 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -432,12 +751,51 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): def vocab_size(self) -> int: return self.model_config.get_vocab_size() @@ -16689,12 +17665,19 @@ index 9d49b4385..7396b0c89 100644 + return rope_scaling.get("type", None) == "mrope" or rope_scaling.get("mrope_section", None) is not None + @torch.inference_mode() - def profile_run(self) -> None: +- def profile_run(self) -> None: ++ def profile_run(self, num_batched_tokens=-1, num_seqs=-1) -> None: # Enable top-k sampling to reflect the accurate memory usage. -@@ -439,6 +756,30 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens ++ assert (num_batched_tokens == -1 or num_batched_tokens > 0) ++ assert (num_seqs == -1 or num_seqs > 0) max_num_seqs = self.scheduler_config.max_num_seqs - ++ if num_batched_tokens != -1: ++ max_num_batched_tokens = num_batched_tokens ++ if num_seqs != -1: ++ max_num_seqs = num_seqs ++ + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request @@ -16718,11 +17701,10 @@ index 9d49b4385..7396b0c89 100644 + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] -+ + # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. 
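# [Illustrative sketch, not part of the patch] The new profile_run signature
# above treats -1 as "keep the scheduler defaults" and any positive value as
# an override; the worker-side hunks later in this patch rely on exactly that
# convention. A minimal standalone version of the resolution logic (the helper
# name and free-standing form are hypothetical):

def resolve_profile_sizes(scheduler_max_tokens: int,
                          scheduler_max_seqs: int,
                          num_batched_tokens: int = -1,
                          num_seqs: int = -1) -> tuple[int, int]:
    # Mirror the asserts in the diff: either -1 or a strictly positive override.
    assert num_batched_tokens == -1 or num_batched_tokens > 0
    assert num_seqs == -1 or num_seqs > 0
    max_num_batched_tokens = scheduler_max_tokens
    max_num_seqs = scheduler_max_seqs
    if num_batched_tokens != -1:
        max_num_batched_tokens = num_batched_tokens
    if num_seqs != -1:
        max_num_seqs = num_seqs
    return max_num_batched_tokens, max_num_seqs

# Example: profile a single 4096-token request instead of the scheduler defaults.
assert resolve_profile_sizes(8192, 256, 4096, 1) == (4096, 1)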
- seqs: List[SequenceGroupMetadata] = [] -@@ -448,6 +789,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -448,6 +806,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): # To exercise the worst scenario for GPU memory consumption, # the number of seqs (batch_size) is chosen to maximize the number # of images processed. @@ -16730,26 +17712,15 @@ index 9d49b4385..7396b0c89 100644 max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: -@@ -461,8 +803,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -461,6 +820,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): "Computed max_num_seqs (%s) to be less than 1. " "Setting it to the minimum value of 1.", expr) max_num_seqs = 1 + ''' batch_size = 0 -+ import os -+ self_max_num_batched_tokens = os.getenv("IPEX_LLM_SELF_MAX_NUM_BATCHED_TOKENS", None) -+ if self_max_num_batched_tokens is not None: -+ max_num_batched_tokens = int(self_max_num_batched_tokens) -+ self_max_num_seqs = os.getenv("IPEX_LLM_SELF_MAX_NUM_SEQS", None) -+ if self_max_num_seqs is not None: -+ max_num_seqs = int(self_max_num_seqs) -+ else: -+ max_num_seqs = 1 for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) -@@ -479,11 +831,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -479,11 +839,14 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=None, @@ -16765,7 +17736,7 @@ index 9d49b4385..7396b0c89 100644 finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) -@@ -493,25 +848,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -493,25 +856,39 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): batch_size=batch_size, dtype=self.model_config.dtype, device=self.device) @@ -16816,7 +17787,7 @@ index 9d49b4385..7396b0c89 100644 """Helper method to prepare the model input based on a given sequence group. Prepares metadata needed for the base model forward pass but not metadata for possible additional steps, e.g., sampling. 
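# [Illustrative sketch, not part of the patch] A small aside on the
# dummy-sequence sizing visible in the profile_run hunk above: the profiling
# pass splits max_num_batched_tokens as evenly as possible across max_num_seqs
# groups, handing the remainder to the first few groups so the lengths sum
# exactly. The helper name below is hypothetical; the expression follows the
# seq_len line shown in the hunk.

def dummy_seq_lens(max_num_batched_tokens: int, max_num_seqs: int) -> list[int]:
    return [
        max_num_batched_tokens // max_num_seqs
        + (group_id < max_num_batched_tokens % max_num_seqs)
        for group_id in range(max_num_seqs)
    ]

# 10 tokens over 4 dummy sequences: the first two groups absorb the remainder.
lens = dummy_seq_lens(10, 4)
assert lens == [3, 3, 2, 2] and sum(lens) == 10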
-@@ -524,6 +893,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -524,6 +901,22 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): return builder.build() # type: ignore @@ -16839,7 +17810,7 @@ index 9d49b4385..7396b0c89 100644 def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], -@@ -563,6 +948,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -563,6 +956,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): raise ValueError( "XPUModelRunner does not support multi-step execution.") @@ -16852,7 +17823,7 @@ index 9d49b4385..7396b0c89 100644 model_executable = self.model if (self.observability_config is not None and self.observability_config.collect_model_forward_time): -@@ -612,3 +1003,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): +@@ -612,3 +1011,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): output.model_forward_time = model_forward_time return [output] @@ -17898,7 +18869,7 @@ index 000000000..550bf81e8 + + return pooling_metadata diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py -index 3aea0d741..060bb10ab 100644 +index 3aea0d741..7421826c3 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -2,9 +2,10 @@ @@ -17913,31 +18884,35 @@ index 3aea0d741..060bb10ab 100644 import oneccl_bindings_for_pytorch # noqa: F401 import torch import torch.distributed -@@ -19,7 +20,8 @@ from vllm.platforms import current_platform +@@ -19,7 +20,10 @@ from vllm.platforms import current_platform from vllm.worker.cache_engine import CacheEngine from vllm.worker.worker import Worker from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase -from vllm.worker.xpu_model_runner import XPUModelRunner +from vllm.worker.xpu_model_runner import XPUModelRunner, XPUModelRunnerBase +from vllm.worker.xpu_pooling_model_runner import XPUPoolingModelRunner ++from vllm.worker.xpu_enc_dec_model_runner import XPUEncoderDecoderModelRunner ++ logger = init_logger(__name__) -@@ -56,8 +58,12 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): +@@ -56,8 +60,14 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): if parallel_config and is_driver_worker: assert rank % parallel_config.tensor_parallel_size == 0, \ "Driver worker should be rank 0 of tensor parallel group." + ModelRunnerClass: Type[XPUModelRunnerBase] = XPUModelRunner + model_config = self.model_config -+ if model_config.task == "embed": ++ if model_config.task == "embed" or model_config.task == "score": + ModelRunnerClass = XPUPoolingModelRunner ++ elif model_config.is_encoder_decoder: ++ ModelRunnerClass = XPUEncoderDecoderModelRunner - self.model_runner = XPUModelRunner( # type: ignore + self.model_runner = ModelRunnerClass( # type: ignore vllm_config=vllm_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, -@@ -65,7 +71,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): +@@ -65,7 +75,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): # Uninitialized cache engine. Will be initialized by # initialize_cache. 
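# [Illustrative sketch, not part of the patch] The runner selection added to
# XPUWorker above reduces to a three-way dispatch on the model config. The
# stub config class here is hypothetical; the runner class names and task
# strings come from the diff.

from dataclasses import dataclass

@dataclass
class _ModelConfigStub:
    task: str
    is_encoder_decoder: bool

def select_runner_name(model_config: _ModelConfigStub) -> str:
    # Pooling tasks (embedding / scoring) -> pooling runner,
    # encoder-decoder models -> enc/dec runner, otherwise the default runner.
    if model_config.task in ("embed", "score"):
        return "XPUPoolingModelRunner"
    if model_config.is_encoder_decoder:
        return "XPUEncoderDecoderModelRunner"
    return "XPUModelRunner"

assert select_runner_name(_ModelConfigStub("embed", False)) == "XPUPoolingModelRunner"
assert select_runner_name(_ModelConfigStub("generate", True)) == "XPUEncoderDecoderModelRunner"
assert select_runner_name(_ModelConfigStub("generate", False)) == "XPUModelRunner"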
self.cache_engine: List[CacheEngine] @@ -17946,15 +18921,75 @@ index 3aea0d741..060bb10ab 100644 def init_device(self) -> None: if self.device_config.device.type == "xpu" and current_platform.is_xpu( -@@ -100,6 +106,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): +@@ -99,16 +109,74 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): + """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. ++ flag = int(os.getenv("IPEX_LLM_FIND_MAX_LENGTH", -1)) ++ if flag != -1: ++ assert flag > 0 ++ torch.xpu.empty_cache() ++ before_memory = torch.xpu.memory_reserved() ++ max_num_batched_tokens = flag ++ max_num_seqs = 1 ++ support_input = [] ++ support_kv_cache = [] ++ while True: ++ print(f"Profiling with max_num_batched_tokens {max_num_batched_tokens}...") ++ self.model_runner.profile_run(max_num_batched_tokens, max_num_seqs) ++ torch.xpu.synchronize() ++ used_memory = torch.xpu.memory_reserved() ++ total_gpu_memory = torch.xpu.get_device_properties( ++ self.local_rank).total_memory ++ free_gpu_memory = total_gpu_memory - used_memory ++ peak_memory = self.init_gpu_memory - free_gpu_memory ++ assert peak_memory > 0 ++ cache_block_size = self.get_cache_block_size_bytes() ++ num_gpu_blocks = int( ++ (total_gpu_memory * self.cache_config.gpu_memory_utilization - ++ peak_memory) // cache_block_size) ++ num_cpu_blocks = int(self.cache_config.swap_space_bytes // ++ cache_block_size) ++ num_gpu_blocks = max(num_gpu_blocks, 0) ++ num_cpu_blocks = max(num_cpu_blocks, 0) ++ gc.collect() ++ torch.xpu.empty_cache() ++ # Begin to handle data... ++ if num_gpu_blocks == 0: ++ break ++ kv_cache_support_length = num_gpu_blocks * self.cache_config.block_size ++ # Too long input... ++ if max_num_batched_tokens > kv_cache_support_length: ++ break ++ support_input.append(max_num_batched_tokens) ++ support_kv_cache.append(kv_cache_support_length) ++ max_num_batched_tokens += 250 ++ ++ print(f"Recommended max input length: {support_input[len(support_input) - 1]}") ++ print(f"{'input length':<15} {'kv cache length':<15}") ++ print("-" * 30) ++ ++ for inp, kv in zip(support_input, support_kv_cache): ++ print(f"{inp:<15} {kv:<15}") torch.xpu.empty_cache() + before_memory = torch.xpu.memory_reserved() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. -@@ -108,7 +115,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): +- self.model_runner.profile_run() ++ self_max_num_batched_tokens = os.getenv("IPEX_LLM_SELF_MAX_NUM_BATCHED_TOKENS", None) ++ if self_max_num_batched_tokens is not None: ++ # If this get set, then profile using max input length ++ max_num_batched_tokens = int(self_max_num_batched_tokens) ++ self_max_num_seqs = os.getenv("IPEX_LLM_SELF_MAX_NUM_SEQS", None) ++ if self_max_num_seqs is not None: ++ max_num_seqs = int(self_max_num_seqs) ++ else: ++ max_num_seqs = 1 ++ self.model_runner.profile_run(max_num_batched_tokens, max_num_seqs) ++ else: ++ self.model_runner.profile_run() + # Calculate the number of blocks that can be allocated with the # profiled peak memory. 
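# [Illustrative sketch, not part of the patch] Both the IPEX_LLM_FIND_MAX_LENGTH
# search above and the normal profiling path below size the KV cache with the
# same arithmetic: usable memory after the profile pass divided by the
# per-block cache size, clamped at zero, then converted to a token budget via
# the block size (the search loop stops once the probed input length exceeds
# that budget). The helper and its argument names are illustrative only; the
# example numbers are made up.

def kv_cache_budget(total_gpu_memory: int,
                    gpu_memory_utilization: float,
                    peak_memory: int,
                    cache_block_size: int,
                    block_size: int) -> tuple[int, int]:
    num_gpu_blocks = int(
        (total_gpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    num_gpu_blocks = max(num_gpu_blocks, 0)
    # Number of tokens those blocks can hold in the KV cache.
    kv_cache_support_length = num_gpu_blocks * block_size
    return num_gpu_blocks, kv_cache_support_length

# e.g. a 16 GiB device at 0.9 utilization with a 6 GiB profiled peak,
# 2 MiB per cache block and 16 tokens per block:
blocks, tokens = kv_cache_budget(16 << 30, 0.9, 6 << 30, 2 << 20, 16)
assert blocks == 4300 and tokens == 68800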
torch.xpu.synchronize() @@ -17963,7 +18998,7 @@ index 3aea0d741..060bb10ab 100644 total_gpu_memory = torch.xpu.get_device_properties( self.local_rank).total_memory free_gpu_memory = total_gpu_memory - used_memory -@@ -132,6 +139,20 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): +@@ -132,6 +200,20 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): num_cpu_blocks = max(num_cpu_blocks, 0) gc.collect() torch.xpu.empty_cache() @@ -17984,7 +19019,7 @@ index 3aea0d741..060bb10ab 100644 return num_gpu_blocks, num_cpu_blocks def _warm_up_model(self) -> None: -@@ -177,9 +198,10 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): +@@ -177,9 +259,10 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker): parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) # global all_reduce needed for overall oneccl warm up