add fused rms norm for Yi and Qwen (#9640)
This commit is contained in:
parent
5636b0ba80
commit
0b6f29a7fc
1 changed files with 10 additions and 0 deletions
|
|
@ -580,10 +580,14 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from bigdl.llm.transformers.models.qwen import qwen_attention_forward
|
from bigdl.llm.transformers.models.qwen import qwen_attention_forward
|
||||||
|
from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.QWenAttention,
|
module.QWenAttention,
|
||||||
qwen_attention_forward
|
qwen_attention_forward
|
||||||
)
|
)
|
||||||
|
convert_forward(model,
|
||||||
|
module.RMSNorm,
|
||||||
|
chatglm_rms_norm_forward)
|
||||||
elif model.config.model_type == "aquila":
|
elif model.config.model_type == "aquila":
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
|
@ -606,6 +610,12 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.MistralRMSNorm,
|
module.MistralRMSNorm,
|
||||||
llama_rms_norm_forward)
|
llama_rms_norm_forward)
|
||||||
|
elif model.config.model_type == "Yi":
|
||||||
|
modeling_module_name = model.__class__.__module__
|
||||||
|
module = importlib.import_module(modeling_module_name)
|
||||||
|
convert_forward(model,
|
||||||
|
module.YiRMSNorm,
|
||||||
|
llama_rms_norm_forward)
|
||||||
elif model.config.model_type == "whisper" and lightweight_bmm:
|
elif model.config.model_type == "whisper" and lightweight_bmm:
|
||||||
if platform.system().lower() == 'windows':
|
if platform.system().lower() == 'windows':
|
||||||
from bigdl.llm.transformers.bmm import SafeBMM
|
from bigdl.llm.transformers.bmm import SafeBMM
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue