LLM: fix model check before attention optimization (#9149)
This commit is contained in:
parent 1a1ddc4144
commit 69942d3826

1 changed file with 24 additions and 23 deletions
@@ -181,29 +181,30 @@ def optimize(model):
     # todo implement 4.28.0 ~ 4.30.2
     pass
 
-    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
-        # chatglm-18b or chatglm2-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
-        from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm2_attention_forward_8eb45c
-                        )
-        convert_forward(model,
-                        module.CoreAttention,
-                        core_attn_forward_8eb45c)
-    elif "chatglm" in model.config._name_or_path:
-        # chatglm-6b
-        modeling_module_name = model.__class__.__module__
-        module = importlib.import_module(modeling_module_name)
-        from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
-        convert_forward(model,
-                        module.SelfAttention,
-                        chatglm_attention_forward
-                        )
-    elif "mpt" in model.config._name_or_path:
+    if model.config.architectures[0] == "ChatGLMModel":
+        if hasattr(model.config, "padded_vocab_size") and model.config.padded_vocab_size == 65024:
+            # chatglm2-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
+            from bigdl.llm.transformers.models.chatglm2 import core_attn_forward_8eb45c
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm2_attention_forward_8eb45c
+                            )
+            convert_forward(model,
+                            module.CoreAttention,
+                            core_attn_forward_8eb45c)
+        elif model.config.vocab_size == 130528:
+            # chatglm-6b
+            modeling_module_name = model.__class__.__module__
+            module = importlib.import_module(modeling_module_name)
+            from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
+            convert_forward(model,
+                            module.SelfAttention,
+                            chatglm_attention_forward
+                            )
+    elif "mpt" in model.config.model_type:
         modeling_module_name = model.__class__.__module__
         attention_module_name = '.'.join(modeling_module_name.split('.')[:-1]) + ".attention"
         module = importlib.import_module(attention_module_name)
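Why the check changed: `model.config._name_or_path` only records the path or repo id the checkpoint was loaded from, so a renamed local directory (or a mirror under a different name) would silently miss the attention optimization. The new checks read fields that ship inside the checkpoint's own config.json: both chatglm-6b and chatglm2-6b declare `architectures[0] == "ChatGLMModel"`, chatglm2-6b is then identified by `padded_vocab_size == 65024`, chatglm-6b by `vocab_size == 130528`, and MPT by its `model_type`. A minimal sketch of the same dispatch, using a hypothetical `detect_model_family` helper that is not part of this patch:

    # Sketch only; detect_model_family is a hypothetical helper for
    # illustration, not part of bigdl-llm.
    from transformers import AutoConfig

    def detect_model_family(checkpoint: str) -> str:
        config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
        architectures = getattr(config, "architectures", None) or []
        if architectures and architectures[0] == "ChatGLMModel":
            # chatglm-6b and chatglm2-6b share the architecture name;
            # the vocabulary size tells them apart.
            if getattr(config, "padded_vocab_size", None) == 65024:
                return "chatglm2-6b"
            if getattr(config, "vocab_size", None) == 130528:
                return "chatglm-6b"
        if "mpt" in getattr(config, "model_type", ""):
            return "mpt"
        return "unknown"

Because the dispatch no longer looks at the name, a checkpoint copied to, say, /models/my-local-glm2 still selects the chatglm2 path as long as its config.json is intact.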
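Two patterns in the hunk deserve a note. First, ChatGLM ships its modeling code alongside the checkpoint (trust_remote_code), so SelfAttention and CoreAttention cannot be imported from a fixed transformers path; `importlib.import_module(model.__class__.__module__)` reaches the dynamically loaded module instead. Second, `convert_forward` swaps the stock attention forward for the optimized one. The sketch below shows the usual forward-rebinding pattern; it is an assumption about how such a helper works, not the verbatim bigdl-llm implementation:

    # Assumed implementation of the forward-swapping pattern, for
    # illustration only; the real helper lives in bigdl.llm.transformers.
    import types
    import torch.nn as nn

    def convert_forward(model: nn.Module, target_cls: type, new_forward) -> None:
        for module in model.modules():
            if module.__class__ is target_cls:
                # Bind new_forward to the instance so `self` resolves to
                # the patched module; weights and buffers are untouched.
                module.forward = types.MethodType(new_forward, module)

Only the computation path changes; parameters, buffers, and the module tree stay exactly as loaded, which is why the patch can be applied after from_pretrained without re-saving anything.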