LLM: fix model check before attention optimization (#9149)

Authored by binbin Deng on 2023-10-12 15:21:51 +08:00; committed by GitHub
parent 1a1ddc4144
commit 69942d3826


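The hunk below replaces substring matching on config._name_or_path with checks on fields of the model config (architectures, padded_vocab_size, vocab_size). As a rough, hedged illustration of how those fields could be inspected up front, here is a stand-alone sketch; the field names and the 65024/130528 values come from the diff itself, while the model id and the use of transformers.AutoConfig are assumptions made only for this example:

    # Hypothetical check mirroring the new logic in the hunk below.
    # The model id "THUDM/chatglm2-6b" is an assumption for illustration.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    is_chatglm_arch = config.architectures[0] == "ChatGLMModel"
    is_chatglm2_6b = is_chatglm_arch and getattr(config, "padded_vocab_size", None) == 65024
    is_chatglm_6b = is_chatglm_arch and getattr(config, "vocab_size", None) == 130528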
@@ -181,29 +181,30 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass
 
-    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
-        # chatglm-18b or chatglm2-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-        from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm2_attention_forward_8eb45c
-                        )
-        convert_forward(model,
-                        module.CoreAttention,
-                        core_attn_forward_8eb45c)
-    elif "chatglm" in model.config._name_or_path:
-        # chatglm-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm_attention_forward
-                        )
-    elif "mpt" in model.config._name_or_path:
+    if model.config.architectures[0] == "ChatGLMModel":
+        if hasattr(model.config, "padded_vocab_size") and model.config.padded_vocab_size == 65024:
+            # chatglm2-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm2_attention_forward_8eb45c
+                            )
+            convert_forward(model,
+                            module.CoreAttention,
+                            core_attn_forward_8eb45c)
+        elif model.config.vocab_size == 130528:
+            # chatglm-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm_attention_forward
+                            )
+    elif "mpt" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
         module = importlib.import_module(attention_module_name)
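convert_forward is the helper this hunk relies on to swap in the optimized attention code. A minimal sketch of what such a module-patching utility typically does, assuming it simply rebinds forward on every submodule of the target class (an illustration of the general technique, not BigDL's exact implementation):

    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
        # Walk the module tree and rebind `forward` on each instance of target_cls,
        # so the patched attention implementation runs the next time the module is called.
        for sub in model.modules():
            if isinstance(sub, target_cls):
                sub.forward = new_forward.__get__(sub, sub.__class__)

In the hunk above, module.SelfAttention and module.CoreAttention are classes from the dynamically imported ChatGLM modeling module, and chatglm2_attention_forward_8eb45c, core_attn_forward_8eb45c, and chatglm_attention_forward are the optimized replacements imported from bigdl.llm.transformers.models.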