From 69942d3826a1f200dce0dd4cdf17048f124793a3 Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Thu, 12 Oct 2023 15:21:51 +0800
Subject: [PATCH] LLM: fix model check before attention optimization (#9149)

---
 .../llm/src/bigdl/llm/transformers/convert.py | 47 ++++++++++---------
 1 file changed, 24 insertions(+), 23 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index e0c76233..5d530405 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -181,29 +181,30 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass
 
-    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
-        # chatglm-18b or chatglm2-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-        from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm2_attention_forward_8eb45c
-                        )
-        convert_forward(model,
-                        module.CoreAttention,
-                        core_attn_forward_8eb45c)
-    elif "chatglm" in model.config._name_or_path:
-        # chatglm-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm_attention_forward
-                        )
-    elif "mpt" in model.config._name_or_path:
+    if model.config.architectures[0] == "ChatGLMModel":
+        if hasattr(model.config, "padded_vocab_size") and model.config.padded_vocab_size == 65024:
+            # chatglm2-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm2_attention_forward_8eb45c
+                            )
+            convert_forward(model,
+                            module.CoreAttention,
+                            core_attn_forward_8eb45c)
+        elif model.config.vocab_size == 130528:
+            # chatglm-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm_attention_forward
+                            )
+    elif "mpt" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
         module = importlib.import_module(attention_module_name)
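
A minimal standalone sketch of the new dispatch logic, for illustration only. It assumes nothing beyond the config fields the patch reads (architectures, padded_vocab_size, vocab_size); the classify_chatglm helper and the sample config are hypothetical, not bigdl-llm API. It shows why matching on config attributes is more robust than the old substring test on _name_or_path, which breaks when a checkpoint is saved under an arbitrary local directory name.

# Sketch only -- hypothetical helper, not part of the patch or bigdl-llm.
from types import SimpleNamespace

def classify_chatglm(config):
    """Return the ChatGLM variant a config describes, or None."""
    archs = getattr(config, "architectures", None) or [None]
    if archs[0] != "ChatGLMModel":
        return None
    if getattr(config, "padded_vocab_size", None) == 65024:
        return "chatglm2-6b"   # value the patch uses to detect chatglm2-6b
    if getattr(config, "vocab_size", None) == 130528:
        return "chatglm-6b"    # value the patch uses to detect chatglm-6b
    return None

# A checkpoint saved under a path that never mentions "chatglm" is still
# classified correctly, which the old _name_or_path substring check missed:
local_copy = SimpleNamespace(
    _name_or_path="/models/my-finetuned-model",
    architectures=["ChatGLMModel"],
    padded_vocab_size=65024,
)
print(classify_chatglm(local_copy))  # -> chatglm2-6b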