diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 17c7978e..fce3b7f7 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1867,10 +1867,12 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
-        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention, \
+            Gemma2SdpaAttention
         from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
+        convert_forward(model, Gemma2SdpaAttention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
         convert_forward(model, Gemma2MLP, mlp_gelu_forward)
     elif model.config.model_type == "Yi":
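
The extra convert_forward call registers the optimized gemma2_attention_forward for Gemma2SdpaAttention as well, so Gemma-2 models loaded with the SDPA attention implementation are also patched. Below is a minimal sketch (not part of this patch) of a load path that would exercise the new registration; it assumes ipex_llm's AutoModelForCausalLM wrapper applies _optimize_post during loading and forwards attn_implementation to transformers, and uses a hypothetical checkpoint id.

    # Illustrative sketch only; checkpoint id and kwarg pass-through are assumptions.
    import torch
    from transformers import AutoTokenizer
    from ipex_llm.transformers import AutoModelForCausalLM

    model_path = "google/gemma-2-9b-it"  # hypothetical Gemma-2 checkpoint

    # With attn_implementation="sdpa", transformers instantiates Gemma2SdpaAttention
    # layers; the added convert_forward(model, Gemma2SdpaAttention, ...) in this diff
    # means those layers also receive the optimized attention forward.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        load_in_4bit=True,
        attn_implementation="sdpa",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    with torch.inference_mode():
        inputs = tokenizer("What does SDPA attention change?", return_tensors="pt")
        output = model.generate(**inputs, max_new_tokens=32)
        print(tokenizer.decode(output[0], skip_special_tokens=True))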