diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 98078146..92944992 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -580,10 +580,14 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.qwen import qwen_attention_forward
+        from bigdl.llm.transformers.models.chatglm2 import chatglm_rms_norm_forward
         convert_forward(model,
                         module.QWenAttention,
                         qwen_attention_forward
                         )
+        convert_forward(model,
+                        module.RMSNorm,
+                        chatglm_rms_norm_forward)
     elif model.config.model_type == "aquila":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
@@ -606,6 +610,12 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MistralRMSNorm,
                         llama_rms_norm_forward)
+    elif model.config.model_type == "Yi":
+        modeling_module_name = model.__class__.__module__
+        module = importlib.import_module(modeling_module_name)
+        convert_forward(model,
+                        module.YiRMSNorm,
+                        llama_rms_norm_forward)
     elif model.config.model_type == "whisper" and lightweight_bmm:
         if platform.system().lower() == 'windows':
             from bigdl.llm.transformers.bmm import SafeBMM
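
Note (not part of the patch): the diff relies on the existing convert_forward helper to swap in the optimized forward implementations (chatglm_rms_norm_forward for QWen's RMSNorm, llama_rms_norm_forward for YiRMSNorm). Its body is not shown in these hunks; the following is a minimal sketch of how such a helper can work, assuming it recursively walks named_children() and rebinds forward on every instance of the target class. The real implementation in convert.py may differ.

    def convert_forward(m, target_m, new_forward):
        # Sketch only: replace the forward method of every submodule of `m`
        # that is an instance of `target_m` with `new_forward`.
        for _, sub_m in m.named_children():
            if isinstance(sub_m, target_m):
                # Bind the replacement function to this instance so `self`
                # is passed correctly, then overwrite the forward attribute.
                bound_method = new_forward.__get__(sub_m, sub_m.__class__)
                setattr(sub_m, "forward", bound_method)
            # Recurse so nested layers (e.g. RMSNorm inside decoder blocks)
            # are covered as well.
            convert_forward(sub_m, target_m, new_forward)

Under that assumption, the new Yi branch simply reuses the LLaMA RMSNorm kernel, since YiRMSNorm has the same computation as LlamaRMSNorm.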