optimize qwen2 vl perf again (#12167)

Yishuo Wang 2024-10-09 16:43:48 +08:00 committed by GitHub
parent 412cf8e20c
commit 78d253165d
2 changed files with 35 additions and 0 deletions


@@ -1691,10 +1691,12 @@ def _optimize_post(model, lightweight_bmm=False):
        module = importlib.import_module(modeling_module_name)
        from ipex_llm.transformers.models.common import rms_norm_forward
        from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
        from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
        from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
        from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
        convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
        convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
        convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
        convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
        convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
    elif model.config.model_type == "cohere":
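For context, convert_forward acts as a patching helper: it rebinds the forward of every submodule that matches the given class, which is how the hunk above routes Qwen2-VL's vision attention through the optimized function. The sketch below is a hypothetical, simplified stand-in for that pattern (convert_forward_sketch is an illustrative name), not the library's actual implementation.

# Hedged sketch of a convert_forward-style patcher, for illustration only.
import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_class, new_forward):
    # Rebind `forward` on every submodule that is an instance of target_class,
    # so those modules run the optimized function instead of the stock one.
    for sub in model.modules():
        if isinstance(sub, target_class):
            sub.forward = types.MethodType(new_forward, sub)

Under this pattern, convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward) sends every vision attention block through the new forward without touching the model's weights or structure.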


@@ -50,6 +50,7 @@ from ipex_llm.utils.common import invalidInputError
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision
from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache
@@ -174,6 +175,38 @@ def qwen2_vl_model_forward(
    )


def qwen2_vision_attention_forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1
                                              ).permute(1, 0, 2, 3).unbind(0)
    q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
    k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
    attention_mask = torch.full(
        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
    )
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
                       cu_seqlens[i - 1]:cu_seqlens[i]] = 0
    q = q.transpose(0, 1)
    k = k.transpose(0, 1)
    v = v.transpose(0, 1)
    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
    attn_weights = attn_weights + attention_mask
    attn_weights = attention_softmax(attn_weights, False)
    attn_output = torch.matmul(attn_weights, v)
    attn_output = attn_output.transpose(0, 1)
    attn_output = attn_output.reshape(seq_length, -1)
    attn_output = self.proj(attn_output)
    return attn_output


def qwen2_vl_attention_forward(
    self,
    hidden_states: torch.Tensor,
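A note on the new vision attention path above: the cu_seqlens loop builds a block-diagonal mask so that patches belonging to different images or frames in the packed sequence cannot attend to each other. The standalone snippet below, with assumed toy shapes and not part of the commit, reproduces that mask construction to make the effect visible.

# Hedged illustration of the block-diagonal mask built from cu_seqlens.
import torch

cu_seqlens = torch.tensor([0, 4, 10])   # two packed segments of lengths 4 and 6
seq_length = int(cu_seqlens[-1])
dtype = torch.float16

mask = torch.full([1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0

# Allowed positions (zeros) form two diagonal blocks, 4x4 and 6x6; every
# off-block position keeps the large negative value and is suppressed by softmax.
print((mask[0] == 0).int())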