From 9b81236a2ed7007ac999591e95f62ea7f8e9f8a1 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 15 Oct 2024 15:54:25 +0800
Subject: [PATCH] optimize qwen2-vl vision (#12203)

---
 .../ipex_llm/transformers/models/internvl.py  |  5 ++-
 .../ipex_llm/transformers/models/qwen2_vl.py  | 38 +++++++++++++------
 .../src/ipex_llm/transformers/models/utils.py | 10 ++++-
 3 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/internvl.py b/python/llm/src/ipex_llm/transformers/models/internvl.py
index 1cecdd25..43ce6f56 100644
--- a/python/llm/src/ipex_llm/transformers/models/internvl.py
+++ b/python/llm/src/ipex_llm/transformers/models/internvl.py
@@ -26,6 +26,7 @@
 import torch
 
 from ipex_llm.utils.common.log4Error import invalidInputError
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 
 
 def _get_pos_embed(self, pos_embed, H, W):
@@ -175,9 +176,9 @@ def intern_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
         q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
         k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
 
-    if x.device.type == "xpu":
+    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
         import xe_addons
-        x = xe_addons.sdp_non_causal(q.contiguous(), k.contiguous(), v.contiguous(), None)
+        x = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), None)
     else:
         attn = ((q * self.scale) @ k.transpose(-2, -1))
         attn = attn.softmax(dim=-1)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
index ac800291..dd0e0de3 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -45,6 +45,7 @@ import torch
 from ipex_llm.transformers.models.common import merge_qkv_base, attention_softmax
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, should_use_fuse_rope
+from ipex_llm.transformers.models.utils import use_sdp_non_causal
 from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common import invalidInputError
 
@@ -191,20 +192,35 @@ def qwen2_vision_attention_forward(
     q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
     k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
 
-    attention_mask = torch.full(
-        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
-    )
-    for i in range(1, len(cu_seqlens)):
-        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
-                       cu_seqlens[i - 1]:cu_seqlens[i]] = 0
-
     q = q.transpose(0, 1)
     k = k.transpose(0, 1)
     v = v.transpose(0, 1)
-    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
-    attn_weights = attn_weights + attention_mask
-    attn_weights = attention_softmax(attn_weights, False)
-    attn_output = torch.matmul(attn_weights, v)
+
+    if len(cu_seqlens) == 2 and cu_seqlens.tolist() == [0, seq_length]:
+        attention_mask = None
+    else:
+        attention_mask = torch.full(
+            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+                           cu_seqlens[i - 1]:cu_seqlens[i]] = 0
+
+    if use_sdp_non_causal(self.head_dim, q.device, q.dtype):
+        import xe_addons
+        q = q.unsqueeze(0)
+        k = k.unsqueeze(0)
+        v = v.unsqueeze(0)
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(0)
+        attn_output = xe_addons.sdp_non_causal(q, k.contiguous(), v.contiguous(), attention_mask)
+        attn_output = attn_output.squeeze(0)
+    else:
+        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+        attn_weights = attention_softmax(attn_weights, False)
+        attn_output = torch.matmul(attn_weights, v)
     attn_output = attn_output.transpose(0, 1)
     attn_output = attn_output.reshape(seq_length, -1)
     attn_output = self.proj(attn_output)
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 1f14bf2f..e4a87266 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -338,12 +338,20 @@ def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
         and head_dim in [-1, 64, 80, 96, 128]  # for now
-        and query_states.device.type == "xpu"  # GPU
+        and query_states.device.type == "xpu"                # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
     )
 
 
+def use_sdp_non_causal(head_dim, device, dtype):
+    return (
+        head_dim in [40, 64, 80]
+        and device.type == "xpu"  # GPU
+        and dtype in [torch.float, torch.half]  # fp32/fp16
+    )
+
+
 def mlp_fusion_check(x, qtype, training):
     invalidInputError(x.dim() == 2, "Here input x's dim should be 2.")
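
A note for context (not part of the patch): the qwen2_vl.py hunk introduces two independent fast paths. First, when cu_seqlens is exactly [0, seq_length] the batch is a single image, the block-diagonal mask would be all zeros, and attention_mask is skipped entirely. Second, when use_sdp_non_causal(head_dim, device, dtype) holds (head_dim in [40, 64, 80], an XPU device, fp32/fp16), attention runs through the fused xe_addons.sdp_non_causal kernel instead of the explicit matmul/softmax path. The snippet below is a minimal CPU-only sketch of the mask logic that the single-image path skips; build_vision_attention_mask is a name chosen here for illustration, not a helper that exists in the repository.

    import torch

    def build_vision_attention_mask(cu_seqlens, seq_length, dtype=torch.float):
        # Same construction as in qwen2_vision_attention_forward: start from -inf everywhere,
        # then open one zero block per image/window delimited by consecutive cu_seqlens entries.
        mask = torch.full([1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
        for i in range(1, len(cu_seqlens)):
            mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
        return mask

    single = build_vision_attention_mask(torch.tensor([0, 8]), 8)    # one image covers all tokens
    multi = build_vision_attention_mask(torch.tensor([0, 3, 8]), 8)  # two images in one sequence
    print(torch.all(single == 0).item())  # True: the mask is a no-op, so passing None is equivalent
    print(torch.all(multi == 0).item())   # False: cross-image positions stay masked out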