LLM: fix model check before attention optimization (#9149)
This commit is contained in:
parent
1a1ddc4144
commit
69942d3826
1 changed files with 24 additions and 23 deletions
|
|
@ -181,8 +181,9 @@ def optimize(model):
|
|||
# todo implement 4.28.0 ~ 4.30.2
|
||||
pass
|
||||
|
||||
if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
|
||||
# chatglm-18b or chatglm2-6b
|
||||
if model.config.architectures[0] == "ChatGLMModel":
|
||||
if hasattr(model.config, "padded_vocab_size") and model.config.padded_vocab_size == 65024:
|
||||
# chatglm2-6b
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
|
||||
|
|
@ -194,7 +195,7 @@ def optimize(model):
|
|||
convert_forward(model,
|
||||
module.CoreAttention,
|
||||
core_attn_forward_8eb45c)
|
||||
elif "chatglm" in model.config._name_or_path:
|
||||
elif model.config.vocab_size == 130528:
|
||||
# chatglm-6b
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
|
|
@ -203,7 +204,7 @@ def optimize(model):
|
|||
module.SelfAttention,
|
||||
chatglm_attention_forward
|
||||
)
|
||||
elif "mpt" in model.config._name_or_path:
|
||||
elif "mpt" in model.config.model_type:
|
||||
modeling_module_name = model.__class__.__module__
|
||||
attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
|
||||
module = importlib.import_module(attention_module_name)
|
||||
|
|
|
|||
Loading…
Reference in a new issue