From f84ca99b9f159bdadfd2e99bf6ca2585dfcb8f5f Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Wed, 3 Jul 2024 15:21:03 +0800 Subject: [PATCH] optimize gemma2 rmsnorm (#11500) --- python/llm/src/ipex_llm/transformers/convert.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index 8a992f75..5d9d4746 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1491,6 +1491,13 @@ def _optimize_post(model, lightweight_bmm=False): convert_forward(model, module.GemmaMLP, gemma_mlp_forward) + elif model.config.model_type == "gemma2": + modeling_module_name = model.__class__.__module__ + module = importlib.import_module(modeling_module_name) + from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward + convert_forward(model, + module.GemmaRMSNorm, + gemma_rms_norm_forward) elif model.config.model_type == "Yi": modeling_module_name = model.__class__.__module__ module = importlib.import_module(modeling_module_name)