optimize gemma2 rmsnorm (#11500)

This commit is contained in:
Xin Qiu 2024-07-03 15:21:03 +08:00 committed by GitHub
parent 61c36ba085
commit f84ca99b9f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1491,6 +1491,13 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.GemmaMLP,
gemma_mlp_forward)
elif model.config.model_type == "gemma2":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
convert_forward(model,
module.GemmaRMSNorm,
gemma_rms_norm_forward)
elif model.config.model_type == "Yi":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)