diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 8a992f75..5d9d4746 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1491,6 +1491,13 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.GemmaMLP, gemma_mlp_forward) + elif model.config.model_type == "gemma2": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward + convert_forward(model, + module.GemmaRMSNorm, + gemma_rms_norm_forward) elif model.config.model_type == "Yi": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name)