LLM: support chatglm-18b convert attention forward in benchmark scripts. (#9072)
* add chatglm-18b convert. * fix if statement. * fix
This commit is contained in:
parent
6de2189e90
commit
fb883100e7
1 changed files with 3 additions and 1 deletions
|
|
@ -153,7 +153,8 @@ def optimize(model):
|
||||||
# todo implement 4.28.0 ~ 4.30.2
|
# todo implement 4.28.0 ~ 4.30.2
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if "chatglm2" in model.config._name_or_path:
|
if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
|
||||||
|
# chatglm-18b or chatglm2-6b
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
|
from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
|
||||||
|
|
@ -166,6 +167,7 @@ def optimize(model):
|
||||||
module.CoreAttention,
|
module.CoreAttention,
|
||||||
core_attn_forward_8eb45c)
|
core_attn_forward_8eb45c)
|
||||||
elif "chatglm" in model.config._name_or_path:
|
elif "chatglm" in model.config._name_or_path:
|
||||||
|
# chatglm-6b
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
module = importlib.import_module(modeling_module_name)
|
module = importlib.import_module(modeling_module_name)
|
||||||
from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
|
from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue