From 78d253165dc0d61054164175141c605d5026cce7 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 9 Oct 2024 16:43:48 +0800
Subject: [PATCH] optimize qwen2 vl perf again (#12167)

---
 .../llm/src/ipex_llm/transformers/convert.py  |  2 ++
 .../ipex_llm/transformers/models/qwen2_vl.py  | 33 +++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 0d25213a..76d13ccb 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1691,10 +1691,12 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
+        from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
         convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
         convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
+        convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
     elif model.config.model_type == "cohere":
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
index 4efeb1b3..81a07ced 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -50,6 +50,7 @@
 from ipex_llm.utils.common import invalidInputError
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
 from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
+from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision
 from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache
@@ -174,6 +175,38 @@ def qwen2_vl_model_forward(
     )
 
 
+def qwen2_vision_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    rotary_pos_emb: torch.Tensor = None
+) -> torch.Tensor:
+    seq_length = hidden_states.shape[0]
+    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1
+                                              ).permute(1, 0, 2, 3).unbind(0)
+    q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+    k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+    attention_mask = torch.full(
+        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+    )
+    for i in range(1, len(cu_seqlens)):
+        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+                            cu_seqlens[i - 1]:cu_seqlens[i]] = 0
+
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+    attn_weights = attn_weights + attention_mask
+    attn_weights = attention_softmax(attn_weights, False)
+    attn_output = torch.matmul(attn_weights, v)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    attn_output = self.proj(attn_output)
+    return attn_output
+
+
 def qwen2_vl_attention_forward(
     self,
     hidden_states: torch.Tensor,
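
Note (not part of the patch): the core of the new qwen2_vision_attention_forward is the block-diagonal attention mask built from cu_seqlens, which prevents patches belonging to different images from attending to one another. The standalone Python sketch below reproduces just that masking step with plain torch.softmax in place of ipex_llm's attention_softmax; the helper name block_diagonal_mask and the toy shapes are illustrative assumptions and do not appear in the patch.

    import math
    import torch

    def block_diagonal_mask(cu_seqlens, dtype, device):
        # Positions outside a patch's own image block get the most negative
        # representable value so softmax assigns them ~0 probability;
        # positions inside the block get 0 (no penalty).
        seq_length = int(cu_seqlens[-1])
        mask = torch.full([1, seq_length, seq_length], torch.finfo(dtype).min,
                          dtype=dtype, device=device)
        for i in range(1, len(cu_seqlens)):
            start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
            mask[..., start:end, start:end] = 0
        return mask

    # Toy usage: two images of 2 and 3 patches, one head, head_dim = 4.
    cu_seqlens = torch.tensor([0, 2, 5])
    q = torch.randn(1, 5, 4)   # (num_heads, seq_len, head_dim)
    k = torch.randn(1, 5, 4)
    v = torch.randn(1, 5, 4)
    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(q.shape[-1])
    attn_weights = attn_weights + block_diagonal_mask(cu_seqlens, q.dtype, q.device)
    attn_output = torch.matmul(torch.softmax(attn_weights, dim=-1), v)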