fix qwen2 vl again (#12174)

Author: Yishuo Wang
Date: 2024-10-10 13:50:01 +08:00
Parent: aef1f671bd
Commit: 535bee5381
2 changed files with 6 additions and 0 deletions


@@ -1691,11 +1691,13 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
+        from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_get_dtype
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
         convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
         convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
+        model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
         convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
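
For readers unfamiliar with the pattern: convert_forward walks the model and rebinds forward on each matching submodule, while the added MethodType line does the same kind of instance-level patch by hand for a non-forward method, get_dtype. It binds the standalone qwen2_vision_get_dtype function to the single model.visual instance, so model.visual.get_dtype() resolves to the replacement. A minimal sketch of that instance-level patching, with hypothetical Visual and patched_get_dtype names standing in for the real Qwen2-VL classes:

from types import MethodType

import torch


class Visual:
    # Stand-in for the Qwen2-VL vision tower; only the patched attribute matters here.
    def __init__(self):
        self.weight = torch.zeros(1, dtype=torch.float16)

    def get_dtype(self) -> torch.dtype:
        raise RuntimeError("original implementation we want to bypass")


def patched_get_dtype(self) -> torch.dtype:
    # MethodType supplies `self`, exactly as in the diff above.
    return self.weight.dtype


visual = Visual()
visual.get_dtype = MethodType(patched_get_dtype, visual)  # patch this instance only
print(visual.get_dtype())  # torch.float16; other Visual instances keep the original

Patching per instance rather than assigning to the class keeps other instances and the upstream class untouched, which matches how _optimize_post mutates only the model it was handed.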


@@ -175,6 +175,10 @@ def qwen2_vl_model_forward(
     )
 
 
+def qwen2_vision_get_dtype(self) -> torch.dtype:
+    return self.patch_embed.proj.weight.dtype
+
+
 def qwen2_vision_attention_forward(
     self,
     hidden_states: torch.Tensor,
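
The replacement itself just reads the dtype off the patch-embedding projection weight, presumably because whatever the upstream get_dtype inspects is no longer reliable once ipex-llm has swapped layers out for low-bit variants. A hedged sketch of the attribute chain it relies on, with made-up module shapes rather than the real Qwen2-VL configuration:

import torch
from torch import nn


class PatchEmbed(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical shapes; the point is only the proj.weight attribute chain.
        self.proj = nn.Conv3d(3, 1280, kernel_size=(2, 14, 14), bias=False)


class VisionTower(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = PatchEmbed()


tower = VisionTower().to(torch.float16)
# Same lookup that qwen2_vision_get_dtype performs in the diff above.
assert tower.patch_embed.proj.weight.dtype == torch.float16

Anchoring the dtype query on a convolution weight that the low-bit conversion leaves alone is presumably what makes this version robust where the previous fix was not.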