Fix Gemma 2 on LNL (#12240)

* Fix gemma 2 on LNL

* Python style fix
This commit is contained in:
Yuwen Hu 2024-10-21 18:25:53 +08:00 committed by GitHub
parent ac2dac857c
commit b3df47486d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1867,10 +1867,12 @@ def _optimize_post(model, lightweight_bmm=False):
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention, \
    Gemma2SdpaAttention
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
convert_forward(model, Gemma2Attention, gemma2_attention_forward)
convert_forward(model, Gemma2SdpaAttention, gemma2_attention_forward)
convert_forward(model, Gemma2Model, gemma2_model_forward)
convert_forward(model, Gemma2MLP, mlp_gelu_forward)
elif model.config.model_type == "Yi":