LLM: support converting chatglm-18b attention forward in benchmark scripts. (#9072)

* Add chatglm-18b convert.

* Fix if statement.

* Fix.
Cengguang Zhang, 2023-09-28 14:04:52 +08:00, committed by GitHub
parent 6de2189e90
commit fb883100e7


@@ -153,7 +153,8 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass
-    if "chatglm2" in model.config._name_or_path:
+    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
+        # chatglm-18b or chatglm2-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
@@ -166,6 +167,7 @@ def optimize(model):
             module.CoreAttention,
             core_attn_forward_8eb45c)
     elif "chatglm" in model.config._name_or_path:
+        # chatglm-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
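For context, the explicit "chatglm-18b" test is needed because that name contains "chatglm" but not "chatglm2": without the extra check, the 18b model would fall through to the chatglm-6b branch even though it ships the chatglm2 modeling code. Below is a minimal sketch of the monkey-patching pattern these hunks feed into. It is not a verbatim copy of the bigdl-llm helper: the body of convert_forward and the module.SelfAttention targets are assumptions inferred from the visible call fragments.

import importlib
import types

def convert_forward(model, target_class, new_forward):
    # Sketch of the assumed helper: walk every submodule and, on each
    # instance of the target attention class, rebind `forward` to the
    # optimized replacement. MethodType binds the function as an instance
    # method so it still receives `self`.
    for _, sub_module in model.named_modules():
        if isinstance(sub_module, target_class):
            sub_module.forward = types.MethodType(new_forward, sub_module)

def optimize(model):
    if "chatglm-18b" in model.config._name_or_path \
            or "chatglm2" in model.config._name_or_path:
        # chatglm-18b or chatglm2-6b: both use the chatglm2 modeling code.
        # The attention classes live in a module fetched at load time
        # (trust_remote_code), so they must be looked up dynamically.
        module = importlib.import_module(model.__class__.__module__)
        from bigdl.llm.transformers.models.chatglm2 import (
            chatglm2_attention_forward_8eb45c,
            core_attn_forward_8eb45c,
        )
        # module.SelfAttention as the first target is an assumption here.
        convert_forward(model, module.SelfAttention,
                        chatglm2_attention_forward_8eb45c)
        convert_forward(model, module.CoreAttention,
                        core_attn_forward_8eb45c)
    elif "chatglm" in model.config._name_or_path:
        # chatglm-6b; this test must come after the chatglm2 branch,
        # since "chatglm" is a substring of "chatglm2".
        module = importlib.import_module(model.__class__.__module__)
        from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
        convert_forward(model, module.SelfAttention, chatglm_attention_forward)

The branch ordering is the load-bearing design choice: dispatch is by substring match on the checkpoint path, so more specific names must be tested before their prefixes.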