LLM: fix model check before attention optimization (#9149)

binbin Deng authored on 2023-10-12 15:21:51 +08:00, committed by GitHub
parent 1a1ddc4144
commit 69942d3826


@@ -181,29 +181,30 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass
-    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
-        # chatglm-18b or chatglm2-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-        from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm2_attention_forward_8eb45c
-                        )
-        convert_forward(model,
-                        module.CoreAttention,
-                        core_attn_forward_8eb45c)
-    elif "chatglm" in model.config._name_or_path:
-        # chatglm-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm_attention_forward
-                        )
-    elif "mpt" in model.config._name_or_path:
+    if model.config.architectures[0] == "ChatGLMModel":
+        if hasattr(model.config, "padded_vocab_size") and model.config.padded_vocab_size == 65024:
+            # chatglm2-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm2_attention_forward_8eb45c
+                            )
+            convert_forward(model,
+                            module.CoreAttention,
+                            core_attn_forward_8eb45c)
+        elif model.config.vocab_size == 130528:
+            # chatglm-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm_attention_forward
+                            )
+    elif "mpt" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
         module = importlib.import_module(attention_module_name)
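
The substance of the fix: model.config._name_or_path merely records the path or hub id the checkpoint was loaded from, so the old substring tests miss a ChatGLM model saved under a custom directory name. The new checks key off fields the model's own config defines: the registered architecture plus the vocabulary size (chatglm2-6b pads its vocabulary to 65024 entries, chatglm-6b uses 130528), and for MPT the config's model_type. Below is a minimal standalone sketch of the same dispatch; the helper name detect_chatglm_variant, the use of AutoConfig, and the example path are illustrative assumptions, not part of the commit:

    # Standalone sketch of the new, path-independent check. Assumed names:
    # detect_chatglm_variant is a hypothetical helper; the magic numbers
    # 65024 and 130528 come from the diff above.
    from transformers import AutoConfig

    def detect_chatglm_variant(config):
        """Classify a ChatGLM checkpoint by intrinsic config fields,
        not by the directory name it happens to be loaded from."""
        archs = getattr(config, "architectures", None) or []
        if archs and archs[0] == "ChatGLMModel":
            if getattr(config, "padded_vocab_size", None) == 65024:
                return "chatglm2-6b"   # chatglm2-6b pads its vocab to 65024
            if getattr(config, "vocab_size", None) == 130528:
                return "chatglm-6b"    # original chatglm-6b vocab size
        return "unknown"

    # A renamed local copy now classifies correctly, which the old
    # `"chatglm2" in model.config._name_or_path` test could not guarantee
    # (the path below is a hypothetical example):
    config = AutoConfig.from_pretrained("./my-renamed-chatglm2-copy",
                                        trust_remote_code=True)
    print(detect_chatglm_variant(config))  # -> "chatglm2-6b"

One caveat the diff inherits: model.config.architectures[0] fails on a config that carries no architectures list, which is why the sketch guards with getattr; the committed code assumes the field is present.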