From b3df47486d786f570e4c85bee5908c4d90ba88e1 Mon Sep 17 00:00:00 2001
From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com>
Date: Mon, 21 Oct 2024 18:25:53 +0800
Subject: [PATCH] Fix Gemma 2 on LNL (#12240)

* Fix gemma 2 on LNL

* Python style fix
---
 python/llm/src/ipex_llm/transformers/convert.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 17c7978e..fce3b7f7 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1867,10 +1867,12 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
-        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention, \
+            Gemma2SdpaAttention
         from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
+        convert_forward(model, Gemma2SdpaAttention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
         convert_forward(model, Gemma2MLP, mlp_gelu_forward)
     elif model.config.model_type == "Yi":