From 0b6f29a7fcc1c0f5596877c2bc675814a215809e Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Fri, 8 Dec 2023 16:04:38 +0800
Subject: [PATCH] add fused rms norm for Yi and Qwen (#9640)

---
 python/llm/src/bigdl/llm/transformers/convert.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 98078146..92944992 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -580,10 +580,14 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen import qwen_attention_forward
+        from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         convert_forward(model,
                         module.QWenAttention,
                         qwen_attention_forward
                         )
+        convert_forward(model,
+                        module.RMSNorm,
+                        chatglm_rms_norm_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -606,6 +610,12 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MistralRMSNorm,
                         llama_rms_norm_forward)
+    elif model.config.model_type == "Yi":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        convert_forward(model,
+                        module.YiRMSNorm,
+                        llama_rms_norm_forward)
     elif model.config.model_type == "whisper" and lightweight_bmm:
         if platform.system().lower() == 'windows':
             from bigdl.llm.transformers.bmm import SafeBMM