add fused rms norm for Yi and Qwen (#9640)

This commit is contained in:
Xin Qiu 2023-12-08 16:04:38 +08:00 committed by GitHub
parent 5636b0ba80
commit 0b6f29a7fc

View file

@ -580,10 +580,14 @@ def _optimize_post(model, lightweight_bmm=False):
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
from bigdl.llm.transformers.models.qwen import qwen_attention_forward
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
convert_forward(model,
module.QWenAttention,
qwen_attention_forward
)
convert_forward(model,
module.RMSNorm,
chatglm_rms_norm_forward)
elif model.config.model_type == "aquila":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
@ -606,6 +610,12 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.MistralRMSNorm,
llama_rms_norm_forward)
elif model.config.model_type == "Yi":
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)
convert_forward(model,
module.YiRMSNorm,
llama_rms_norm_forward)
elif model.config.model_type == "whisper" and lightweight_bmm:
if platform.system().lower() == 'windows':
from bigdl.llm.transformers.bmm import SafeBMM