update patches (#13290)

Signed-off-by: liu-shaojun <shaojun.liu@intel.com>

parent 9cfdf143a2
commit cac90a9238

1 changed file with 79 additions and 49 deletions
@@ -9200,7 +9200,7 @@ index c3d210c27..6a9c7c798 100644
 +                                              max_seq_length, slice_offset,
 +                                              slice_size, add_inputs)
 diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py
-index d3c61ea26..7d7adad15 100644
+index d3c61ea26..8e1219180 100644
 --- a/vllm/attention/backends/ipex_attn.py
 +++ b/vllm/attention/backends/ipex_attn.py
 @@ -5,7 +5,7 @@ from dataclasses import dataclass
@@ -9723,7 +9723,7 @@ index d3c61ea26..7d7adad15 100644
      def forward(
          self,
          layer: AttentionLayer,
-@@ -195,84 +566,224 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
+@@ -195,84 +566,236 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
          Returns:
              shape = [num_tokens, num_heads * head_size]
          """
@@ -9948,6 +9948,17 @@ index d3c61ea26..7d7adad15 100644
 +                                    self.logits_soft_cap,
 +                                    self.scale).squeeze(0).movedim(
 +                                        query.dim() - 2, 0)
++                        elif attn_type != AttentionType.DECODER:
++                            import xe_addons
++                            if mask is not None:
++                                mask = mask.unsqueeze(0)
++                            sub_out = xe_addons.sdp_non_causal(
++                                query[None, :, start_q:end_q, :].contiguous(),
++                                key[None, :, start_kv:end_kv, :].contiguous(),
++                                value[None, :, start_kv:end_kv, :].contiguous(),
++                                mask,
++                                scale).squeeze(0).movedim(
++                                    query.dim() - 2, 0)
 +                        else:
 +                            sub_out = torch.nn.functional.scaled_dot_product_attention(
 +                                query[None, :, start_q:end_q, :],
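
Note: the added branch above routes non-decoder (encoder/cross-attention) slices
through xe_addons.sdp_non_causal. As a point of reference only (not part of the
patch), the sketch below computes the same kind of non-causal attention over one
slice with plain PyTorch; it assumes PyTorch >= 2.1 for the explicit scale
argument, and the shape handling mirrors the squeeze(0).movedim(...) in the hunk.

    # Illustrative reference only, not vLLM/IPEX code.
    import torch
    import torch.nn.functional as F

    def sdp_non_causal_reference(query, key, value, mask=None, scale=None):
        # query/key/value: [num_heads, seq_len, head_dim] slices, as in the hunk.
        q, k, v = query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0)
        attn_mask = mask.unsqueeze(0) if mask is not None else None
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=False, scale=scale)
        # Drop the batch dim and put the token dim first, matching
        # .squeeze(0).movedim(query.dim() - 2, 0) above.
        return out.squeeze(0).movedim(query.dim() - 2, 0)

    q = torch.randn(8, 5, 64)   # [heads, q_len, head_dim]
    k = torch.randn(8, 7, 64)
    v = torch.randn(8, 7, 64)
    print(sdp_non_causal_reference(q, k, v, scale=64 ** -0.5).shape)  # [5, 8, 64]
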
@@ -9965,6 +9976,7 @@ index d3c61ea26..7d7adad15 100644
                  # prefix-enabled attention
 -                raise RuntimeError(
 -                    "IPEX backend doesn't support prefix decoding.")
++                query = query[:num_prefill_tokens]
 +                if self.num_kv_heads != self.num_heads:
 +                    key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
 +                    value = value.repeat_interleave(self.num_queries_per_kv,
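
Note: when num_kv_heads != num_heads (grouped-query attention), the
prefix-attention path added above repeats each KV head so it lines up with its
group of query heads. A minimal sketch of that expansion, with illustrative
shapes:

    import torch

    num_heads, num_kv_heads, head_size = 8, 2, 16
    num_queries_per_kv = num_heads // num_kv_heads       # 4 query heads per KV head

    key = torch.randn(10, num_kv_heads, head_size)       # [num_tokens, kv_heads, head_size]
    key = key.repeat_interleave(num_queries_per_kv, dim=1)
    print(key.shape)                                     # torch.Size([10, 8, 16])
    # Heads 0..3 are copies of KV head 0, heads 4..7 are copies of KV head 1.
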
@@ -9981,8 +9993,8 @@ index d3c61ea26..7d7adad15 100644
 +                    out = vllm._C.ops.context_attention_forward_v2(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item(), torch.amax(query_len).item())
 +                else:
 +                    out = vllm._C.ops.context_attention_forward_v1(query, key_cache, value_cache, prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, prefill_meta.context_lens, prefill_meta.max_seqlen, torch.amax(prefill_meta.context_lens).item())
-+                assert output[:num_prefill_query_tokens].shape == out.shape
-+                output[:num_prefill_query_tokens] = out
++                assert output[:num_prefill_tokens].shape == out.shape
++                output[:num_prefill_tokens] = out
  
 -        else:
 +        if decode_meta := attn_metadata.decode_metadata:
@@ -10004,7 +10016,7 @@ index d3c61ea26..7d7adad15 100644
              # NOTE(woosuk): We use a simple heuristic to decide whether to use
              # PagedAttention V1 or V2. If the number of partitions is 1, we use
              # V1 to avoid the overhead of reduction. Also, if the number of
-@@ -281,59 +792,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
+@@ -281,59 +804,86 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
              # TODO(woosuk): Tune this heuristic.
              # For context len > 8192, use V2 kernel to avoid shared memory
              # shortage.
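
Note: the comments above describe the V1-vs-V2 PagedAttention choice. The
sketch below only illustrates the shape of that decision; the partition size
and the work-size threshold are typical vLLM defaults, not values read from
this backend.

    PARTITION_SIZE = 512  # assumed partition size, for illustration only

    def use_paged_attention_v1(max_seq_len, num_seqs, num_heads):
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        # V1 skips the cross-partition reduction, so prefer it when one
        # partition suffices or there is already enough parallel work;
        # beyond ~8192 tokens use V2 to avoid shared-memory shortage.
        return max_seq_len <= 8192 and (num_partitions == 1
                                        or num_seqs * num_heads > 512)

    print(use_paged_attention_v1(1024, num_seqs=4, num_heads=32))   # True
    print(use_paged_attention_v1(16384, num_seqs=1, num_heads=32))  # False
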
@@ -10136,7 +10148,7 @@ index d3c61ea26..7d7adad15 100644
  
              # Reshape the output tensor.
          return output.view(-1, self.num_heads * self.head_size)
-@@ -386,3 +924,114 @@ def _make_sliding_window_bias(
+@@ -386,3 +936,114 @@ def _make_sliding_window_bias(
          attn_biases.append(mask.to(dtype))
  
      return attn_biases
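
Note: _make_sliding_window_bias (touched by the hunk above) builds one additive
bias per sequence. The following is a generic reconstruction of what such a
helper computes, not the file's exact code: each position may attend to itself
and at most window_size - 1 earlier positions; everything else gets -inf.

    import torch

    def make_sliding_window_bias(seq_lens, window_size, dtype):
        attn_biases = []
        for seq_len in seq_lens:
            idx = torch.arange(seq_len)
            causal = idx[None, :] <= idx[:, None]                  # no looking ahead
            in_window = (idx[:, None] - idx[None, :]) < window_size
            mask = torch.zeros(seq_len, seq_len)
            mask.masked_fill_(~(causal & in_window), float("-inf"))
            attn_biases.append(mask.to(dtype))
        return attn_biases

    print(make_sliding_window_bias([4], window_size=2, dtype=torch.float32)[0])
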
@@ -13637,18 +13649,28 @@ index c4d02e5dd..2831a5a12 100644
          )
  
 diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
-index 1e6ff1fec..90ebe5ca9 100644
+index 1e6ff1fec..efd659075 100644
 --- a/vllm/model_executor/models/qwen2_5_vl.py
 +++ b/vllm/model_executor/models/qwen2_5_vl.py
-@@ -302,23 +302,33 @@ class Qwen2_5_VisionAttention(nn.Module):
+@@ -302,23 +302,37 @@ class Qwen2_5_VisionAttention(nn.Module):
                                        "(b s) ... -> b s ...",
                                        b=batch_size)
          elif self.attn_backend == _Backend.TORCH_SDPA:
 -            # Execute attention entry by entry for speed & less VRAM.
--            outputs = []
--            for i in range(1, len(cu_seqlens)):
--                start_idx = cu_seqlens[i - 1]
--                end_idx = cu_seqlens[i]
++            # TODO(xiangyu): Maybe add attn_backend xpu?
+             outputs = []
++            head_dim = q.shape[-1]
++            import math
++            import xe_addons
++            head_dim = q.shape[-1]
++            scale = 1 / math.sqrt(head_dim)
++            q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
++            q = q.contiguous()
++            k = k.contiguous()
++            v = v.contiguous()
+             for i in range(1, len(cu_seqlens)):
+                 start_idx = cu_seqlens[i - 1]
+                 end_idx = cu_seqlens[i]
 -                q_i = q[:, start_idx:end_idx]
 -                k_i = k[:, start_idx:end_idx]
 -                v_i = v[:, start_idx:end_idx]
@@ -13659,39 +13681,28 @@ index 1e6ff1fec..90ebe5ca9 100644
 -                                                          v_i,
 -                                                          dropout_p=0.0)
 -                output_i = rearrange(output_i, "b h s d -> b s h d ")
--                outputs.append(output_i)
++                q_i = q[:, :, start_idx:end_idx]
++                k_i = k[:, :, start_idx:end_idx]
++                v_i = v[:, :, start_idx:end_idx]
++                # output_i = F.scaled_dot_product_attention(q_i,
++                #                                           k_i,
++                #                                           v_i,
++                #                                           dropout_p=0.0)
++                output_i = xe_addons.sdp_non_causal(
++                    q_i,
++                    k_i,
++                    v_i,
++                    None,
++                    scale)
++                # output_i = rearrange(output_i, "b h s d -> b s h d ")
+                 outputs.append(output_i)
 -            context_layer = torch.cat(outputs, dim=1)
-+            # TODO(xiangyu): Maybe add attn_backend xpu?
-+            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
-+            from vllm._ipex_ops import ipex_ops
-+            output = torch.empty(
-+                        (q.shape[0], q.shape[1], q.shape[2]),
-+                        dtype=q.dtype,
-+                        device=q.device)
-+            import math
-+            head_dim = q.shape[-1]
-+            scale = 1 / math.sqrt(head_dim)
-+            ipex_ops.varlen_attention(q, k, v, output,
-+                                    cu_seqlens,
-+                                    cu_seqlens,
-+                                    max_seqlen,
-+                                    max_seqlen,
-+                                    pdropout=0,
-+                                    softmax_scale=scale,
-+                                    zero_tensors=False,
-+                                    is_causal=False,
-+                                    return_softmax=False,
-+                                    gen_=None,
-+                                    logits_soft_cap=0
-+                                    )
-+
-+            context_layer = rearrange(output,
-+                                      "(b s) ... -> b s ...",
-+                                      b=batch_size)
++            context_layer = torch.cat(outputs, dim=2)
++            context_layer = rearrange(context_layer, "b h s d -> b s h d")
          elif self.attn_backend == _Backend.XFORMERS:
              from xformers import ops as xops
              from xformers.ops.fmha.attn_bias import BlockDiagonalMask
-@@ -613,10 +623,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
+@@ -613,10 +627,11 @@ class Qwen2_5_VisionTransformer(nn.Module):
          cu_seqlens: torch.Tensor,
      ) -> tuple[Optional[int], Optional[list[int]]]:
          max_seqlen, seqlens = None, None
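
Note: in the TORCH_SDPA branch rewritten above, cu_seqlens holds the cumulative
token boundaries of the images packed into one batch, and the loop runs
attention independently on each segment (the patch calls
xe_addons.sdp_non_causal; this reference uses plain torch SDPA, whose default
scaling is the same 1/sqrt(head_dim) the patch computes). Shapes follow the
"b h s d" layout the patch rearranges into.

    import torch
    import torch.nn.functional as F

    def sdpa_over_cu_seqlens(q, k, v, cu_seqlens):
        # q, k, v: [batch, heads, total_tokens, head_dim]
        outputs = []
        for i in range(1, len(cu_seqlens)):
            start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
            outputs.append(F.scaled_dot_product_attention(
                q[:, :, start:end], k[:, :, start:end], v[:, :, start:end]))
        # Concatenate along the token axis, as torch.cat(outputs, dim=2) does above.
        return torch.cat(outputs, dim=2)

    q = k = v = torch.randn(1, 4, 6, 8)
    cu_seqlens = torch.tensor([0, 2, 6])      # two sequences: 2 and 4 tokens
    print(sdpa_over_cu_seqlens(q, k, v, cu_seqlens).shape)  # torch.Size([1, 4, 6, 8])
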
@@ -13707,7 +13718,7 @@ index 1e6ff1fec..90ebe5ca9 100644
          return max_seqlen, seqlens
  
      def forward(
-@@ -1082,7 +1093,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
+@@ -1082,7 +1097,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
                      image_input=image_input,
                      video_input=video_input)
                  input_ids = None
@@ -17191,7 +17202,7 @@ index 000000000..dffc7b367
 +        return (attn_metadata, encoder_input_tokens_tensor,
 +                encoder_input_positions_tensor)
 diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
-index 9d49b4385..dc5e95f4e 100644
+index 9d49b4385..065a09614 100644
 --- a/vllm/worker/xpu_model_runner.py
 +++ b/vllm/worker/xpu_model_runner.py
 @@ -5,8 +5,8 @@ import time
@@ -17526,7 +17537,7 @@ index 9d49b4385..dc5e95f4e 100644
 +        slot_mapping_tensor = torch.tensor(slot_mapping,
 +                                           dtype=torch.long,
 +                                           device=self.device)
-+        if need_block_table or "bge" in self.runner.model_config.model.lower():
++        if need_block_table or "reranker" in self.runner.model_config.model.lower() or "embedding" in self.runner.model_config.model.lower():
 +            seq_lens_tensor = torch.tensor(seq_lens,
 +                                        dtype=torch.int,
 +                                        device=self.device)
@@ -18998,7 +19009,7 @@ index 000000000..550bf81e8
 +
 +        return pooling_metadata
 diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
-index 3aea0d741..7421826c3 100644
+index 3aea0d741..90ee8e2f6 100644
 --- a/vllm/worker/xpu_worker.py
 +++ b/vllm/worker/xpu_worker.py
 @@ -2,9 +2,10 @@
@@ -19041,16 +19052,35 @@ index 3aea0d741..7421826c3 100644
              vllm_config=vllm_config,
              kv_cache_dtype=self.cache_config.cache_dtype,
              is_driver_worker=is_driver_worker,
-@@ -65,7 +75,7 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
+@@ -65,7 +75,26 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
          # Uninitialized cache engine. Will be initialized by
          # initialize_cache.
          self.cache_engine: List[CacheEngine]
 -        self.gpu_cache: Optional[List[List[torch.Tensor]]]
 +        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
++
++        # Torch profiler. Enabled and configured through env vars:
++        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
++        import vllm.envs as envs
++        if envs.VLLM_TORCH_PROFILER_DIR:
++            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
++            logger.info("Profiling enabled. Traces will be saved to: %s",
++                        torch_profiler_trace_dir)
++            self.profiler = torch.profiler.profile(
++                activities=[
++                    torch.profiler.ProfilerActivity.CPU,
++                    torch.profiler.ProfilerActivity.XPU,
++                ],
++                with_stack=True,
++                on_trace_ready=torch.profiler.tensorboard_trace_handler(
++                    torch_profiler_trace_dir, use_gzip=True))
++        else:
++            self.profiler = None
++
  
      def init_device(self) -> None:
          if self.device_config.device.type == "xpu" and current_platform.is_xpu(
-@@ -99,16 +109,74 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
+@@ -99,16 +128,74 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
          """
          # Profile the memory usage of the model and get the maximum number of
          # cache blocks that can be allocated with the remaining free memory.
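
Note: the hunk above wires an optional torch.profiler into XPUWorker, gated by
the VLLM_TORCH_PROFILER_DIR environment variable. The standalone sketch below
shows the same profiler configuration outside vLLM; it assumes a PyTorch build
with XPU support (for ProfilerActivity.XPU), and the matmul is just a demo
workload.

    import os
    import torch

    trace_dir = os.environ.get("VLLM_TORCH_PROFILER_DIR", "./xpu_traces")

    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.XPU,
        ],
        with_stack=True,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            trace_dir, use_gzip=True),
    )

    with profiler:
        x = torch.randn(1024, 1024)
        (x @ x).sum()
    # A gzipped trace viewable in TensorBoard/Chrome is written to trace_dir.
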
@@ -19127,7 +19157,7 @@ index 3aea0d741..7421826c3 100644
          total_gpu_memory = torch.xpu.get_device_properties(
              self.local_rank).total_memory
          free_gpu_memory = total_gpu_memory - used_memory
-@@ -132,6 +200,20 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
+@@ -132,6 +219,20 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
          num_cpu_blocks = max(num_cpu_blocks, 0)
          gc.collect()
          torch.xpu.empty_cache()
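
Note: the surrounding code turns free device memory into a KV-cache block budget
(free_gpu_memory = total_gpu_memory - used_memory above). The sketch below only
illustrates that arithmetic; the block size and utilization are example values,
the real ones come from the cache and model configs.

    def num_cache_blocks(total_gpu_memory, used_memory,
                         cache_block_size, gpu_memory_utilization):
        # Memory we are allowed to use, minus what the model already holds,
        # divided by the size of one KV-cache block.
        usable = total_gpu_memory * gpu_memory_utilization - used_memory
        return max(int(usable // cache_block_size), 0)

    print(num_cache_blocks(total_gpu_memory=16 * 2**30, used_memory=6 * 2**30,
                           cache_block_size=2 * 2**20, gpu_memory_utilization=0.9))
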
@@ -19148,7 +19178,7 @@ index 3aea0d741..7421826c3 100644
          return num_gpu_blocks, num_cpu_blocks
  
      def _warm_up_model(self) -> None:
-@@ -177,9 +259,10 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
+@@ -177,9 +278,10 @@ class XPUWorker(LoRANotSupportedWorkerBase, Worker):
              parallel_config.tensor_parallel_size,
              parallel_config.pipeline_parallel_size)
          # global all_reduce needed for overall oneccl warm up