diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 00b5e001..8e5cf426 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1293,6 +1293,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.mllama import mllama_vision_attention_forward
         convert_forward(model, module.MllamaVisionAttention, mllama_vision_attention_forward)
+        convert_forward(model, module.MllamaVisionSdpaAttention, mllama_vision_attention_forward)

         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.common import mlp_silu_forward
@@ -1303,7 +1304,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.MllamaTextMLP, mlp_silu_forward)
         convert_forward(model, module.MllamaTextModel, mllama_text_model_forward)
         convert_forward(model, module.MllamaTextSelfAttention, llama_attention_forward)
+        convert_forward(model, module.MllamaTextSelfSdpaAttention, llama_attention_forward)
         convert_forward(model, module.MllamaTextCrossAttention, mllama_cross_attention_forward)
+        convert_forward(model, module.MllamaTextCrossSdpaAttention, mllama_cross_attention_forward)
     elif model.config.model_type == "llama":
         from transformers.models.llama.modeling_llama import LlamaRMSNorm
         from transformers.models.llama.modeling_llama import LlamaMLP
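For context, here is a minimal sketch of what a `convert_forward`-style helper typically does, assuming it matches submodules by exact class rather than `isinstance` (which would explain why the `*SdpaAttention` subclasses must be registered separately in this diff). This is an illustration of the pattern, not the library's actual implementation; `convert_forward_sketch` is a hypothetical name.

```python
import types
import torch.nn as nn


def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    """Rebind `forward` on every submodule whose class is exactly `target_cls`.

    Assumption: matching is by exact class, so subclasses such as
    MllamaVisionSdpaAttention are NOT covered by registering only
    MllamaVisionAttention -- hence the extra convert_forward calls above.
    """
    for submodule in model.modules():
        if type(submodule) is target_cls:
            # Bind the optimized function as an instance method so `self` is passed.
            submodule.forward = types.MethodType(new_forward, submodule)
```

Under that assumption, a model loaded with `attn_implementation="sdpa"` instantiates the SDPA attention subclasses, so without the new registrations those layers would keep their stock forwards instead of the optimized ones.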