Fix llama 3.2 vision on LNL (#12264)

* Fix llama 3.2 vision on LNL

* Small fix
Yuwen Hu 2024-10-25 16:23:31 +08:00 committed by GitHub
parent 94c4568988
commit 43b25a2fe7

@@ -1293,6 +1293,7 @@ def _optimize_post(model, lightweight_bmm=False):
         module = importlib.import_module(modeling_module_name)
         from ipex_llm.transformers.models.mllama import mllama_vision_attention_forward
         convert_forward(model, module.MllamaVisionAttention, mllama_vision_attention_forward)
+        convert_forward(model, module.MllamaVisionSdpaAttention, mllama_vision_attention_forward)
         from ipex_llm.transformers.models.common import rms_norm_forward
         from ipex_llm.transformers.models.common import mlp_silu_forward
@@ -1303,7 +1304,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model, module.MllamaTextMLP, mlp_silu_forward)
         convert_forward(model, module.MllamaTextModel, mllama_text_model_forward)
         convert_forward(model, module.MllamaTextSelfAttention, llama_attention_forward)
+        convert_forward(model, module.MllamaTextSelfSdpaAttention, llama_attention_forward)
         convert_forward(model, module.MllamaTextCrossAttention, mllama_cross_attention_forward)
+        convert_forward(model, module.MllamaTextCrossSdpaAttention, mllama_cross_attention_forward)
     elif model.config.model_type == "llama":
         from transformers.models.llama.modeling_llama import LlamaRMSNorm
         from transformers.models.llama.modeling_llama import LlamaMLP
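
The change registers the SDPA attention subclasses with convert_forward in addition to the eager ones. Below is a minimal, self-contained sketch of that patching pattern, assuming convert_forward replaces the forward of modules matched by exact class (which would explain why the *SdpaAttention subclasses need their own calls); the Toy* classes and optimized_forward are illustrative stand-ins, not ipex_llm APIs.

# Minimal sketch of the convert_forward patching pattern, assuming an
# exact-class match (hypothetical Toy* classes; not ipex_llm's real code).
import torch
import torch.nn as nn


def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
    # Rebind `forward` on every sub-module whose exact class is target_cls.
    for module in model.modules():
        if type(module) is target_cls:
            module.forward = new_forward.__get__(module, target_cls)


class ToyAttention(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


class ToySdpaAttention(ToyAttention):
    # Mirrors how transformers derives an SDPA attention class from the eager one.
    pass


def optimized_forward(self, x):
    # Stand-in for an optimized forward such as mllama_vision_attention_forward.
    return self.proj(x)


model = nn.Sequential(ToyAttention(), ToySdpaAttention())

# With an exact-class match, converting only ToyAttention would leave
# ToySdpaAttention on its original forward; both calls are needed, which is
# what the diff above adds for the Mllama*Sdpa* attention classes.
convert_forward(model, ToyAttention, optimized_forward)
convert_forward(model, ToySdpaAttention, optimized_forward)

print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 8])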