optimize qwen2 vl perf again (#12167)
parent 412cf8e20c
commit 78d253165d
2 changed files with 35 additions and 0 deletions
@@ -1691,10 +1691,12 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
+        from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
         convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
         convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
+        convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
     elif model.config.model_type == "cohere":
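Note on the hunk above: the only wiring needed for the new vision path is one extra import and one extra convert_forward call binding module.VisionAttention to qwen2_vision_attention_forward. As a rough, simplified sketch of what a convert_forward-style patch does (illustrative only, not the actual ipex_llm implementation; the helper name below is made up):

import types
import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> nn.Module:
    # Walk every submodule and rebind `forward` on instances of the target class,
    # so those modules run the optimized implementation from then on.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = types.MethodType(new_forward, module)
    return model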
@@ -50,6 +50,7 @@ from ipex_llm.utils.common import invalidInputError
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLAttention
 from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
+from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision
 from transformers.models.qwen2_vl.modeling_qwen2_vl import repeat_kv
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.cache_utils import Cache
@@ -174,6 +175,38 @@ def qwen2_vl_model_forward(
     )
 
 
+def qwen2_vision_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    cu_seqlens: torch.Tensor,
+    rotary_pos_emb: torch.Tensor = None
+) -> torch.Tensor:
+    seq_length = hidden_states.shape[0]
+    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1
+                                              ).permute(1, 0, 2, 3).unbind(0)
+    q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+    k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+    attention_mask = torch.full(
+        [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
+    )
+    for i in range(1, len(cu_seqlens)):
+        attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+                       cu_seqlens[i - 1]:cu_seqlens[i]] = 0
+
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+    attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
+    attn_weights = attn_weights + attention_mask
+    attn_weights = attention_softmax(attn_weights, False)
+    attn_output = torch.matmul(attn_weights, v)
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    attn_output = self.proj(attn_output)
+    return attn_output
+
+
 def qwen2_vl_attention_forward(
     self,
     hidden_states: torch.Tensor,
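The added qwen2_vision_attention_forward runs plain matmul attention over the vision tokens with a block-diagonal mask built from cu_seqlens, so tokens can only attend within their own segment, and the scores go through attention_softmax before being projected back out with self.proj. A small self-contained sketch of just the mask construction, on made-up sizes (two segments of 3 and 4 tokens), to show the pattern the loop produces:

import torch

cu_seqlens = torch.tensor([0, 3, 7])   # hypothetical cumulative segment boundaries
seq_length = int(cu_seqlens[-1])

# Start fully masked, then zero out each [start:end, start:end] block.
mask = torch.full([1, seq_length, seq_length], torch.finfo(torch.float32).min)
for i in range(1, len(cu_seqlens)):
    start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[..., start:end, start:end] = 0

# (mask[0] == 0) is True on a 3x3 block and a 4x4 block along the diagonal.
print((mask[0] == 0).int())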