From 43b25a2fe7f8ca0a6ddfbae1e3bd4bd73d4722f0 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Fri, 25 Oct 2024 16:23:31 +0800
Subject: [PATCH] Fix llama 3.2 vision on LNL (#12264)

* Fix llama 3.2 vision on LNL

* Small fix
---
 python/llm/src/ipex_llm/transformers/convert.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 00b5e001..8e5cf426 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1293,6 +1293,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.mllama import mllama_vision_attention_forward
         convert_forward(model, module.MllamaVisionAttention, mllama_vision_attention_forward)
+        convert_forward(model, module.MllamaVisionSdpaAttention, mllama_vision_attention_forward)

         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.common import mlp_silu_forward
@@ -1303,7 +1304,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.MllamaTextMLP, mlp_silu_forward)
         convert_forward(model, module.MllamaTextModel, mllama_text_model_forward)
         convert_forward(model, module.MllamaTextSelfAttention, llama_attention_forward)
+        convert_forward(model, module.MllamaTextSelfSdpaAttention, llama_attention_forward)
         convert_forward(model, module.MllamaTextCrossAttention, mllama_cross_attention_forward)
+        convert_forward(model, module.MllamaTextCrossSdpaAttention, mllama_cross_attention_forward)
     elif model.config.model_type == "llama":
         from transformers.models.llama.modeling_llama import LlamaRMSNorm
         from transformers.models.llama.modeling_llama import LlamaMLP
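
Note on the fix: the transformers release that introduced Mllama dispatches attention through dedicated *SdpaAttention subclasses (e.g. MllamaTextSelfSdpaAttention) when attn_implementation is "sdpa", which is the default, so patching only the eager classes leaves the default code path unoptimized. The three added convert_forward calls register the same replacement forwards for the SDPA variants. Below is a minimal sketch of how a convert_forward-style helper can be implemented; it is illustrative only, not the actual code in ipex_llm.transformers.convert:

    import types

    def convert_forward(model, target_cls, new_forward):
        """Rebind new_forward as the forward method of every submodule that is
        an instance of target_cls (a sketch, not the real ipex_llm helper)."""
        for submodule in model.modules():
            if isinstance(submodule, target_cls):
                # MethodType binds the function to the module instance,
                # so self is supplied automatically on each call.
                submodule.forward = types.MethodType(new_forward, submodule)

With such a helper, registering the optimized forward for both MllamaTextSelfAttention and MllamaTextSelfSdpaAttention ensures the patch takes effect regardless of which attention implementation transformers selects at load time.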