Add a get_dtype override for the Qwen2-VL vision module.

qwen2_vision_get_dtype returns the dtype of self.patch_embed.proj.weight, and
_optimize_post binds it onto model.visual as get_dtype via MethodType.
NOTE(review): leading whitespace of hunk lines was reconstructed from a
whitespace-collapsed copy; confirm against the target files before applying.

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 76d13ccb..b7899d73 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1691,11 +1691,13 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.qwen2 import qwen2_mlp_forward
+        from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_get_dtype
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vision_attention_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_model_forward
         from ipex_llm.transformers.models.qwen2_vl import qwen2_vl_attention_forward
         convert_forward(model, module.Qwen2RMSNorm, rms_norm_forward)
         convert_forward(model, module.Qwen2MLP, qwen2_mlp_forward)
+        model.visual.get_dtype = MethodType(qwen2_vision_get_dtype, model.visual)
         convert_forward(model, module.VisionAttention, qwen2_vision_attention_forward)
         convert_forward(model, module.Qwen2VLModel, qwen2_vl_model_forward)
         convert_forward(model, module.Qwen2VLAttention, qwen2_vl_attention_forward)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
index 81a07ced..ac800291 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -175,6 +175,10 @@ def qwen2_vl_model_forward(
     )
 
 
+def qwen2_vision_get_dtype(self) -> torch.dtype:
+    return self.patch_embed.proj.weight.dtype
+
+
 def qwen2_vision_attention_forward(
     self,
     hidden_states: torch.Tensor,