optimize gemma2 rmsnorm (#11500)
This commit is contained in:
parent
61c36ba085
commit
f84ca99b9f
1 changed file with 7 additions and 0 deletions
|
|
@@ -1491,6 +1491,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
module.GemmaMLP,
|
||||
gemma_mlp_forward)
|
||||
elif model.config.model_type == "gemma2":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
|
||||
convert_forward(model,
|
||||
module.GemmaRMSNorm,
|
||||
gemma_rms_norm_forward)
|
||||
elif model.config.model_type == "Yi":
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
|
|||
Loading…
Reference in a new issue