From fb883100e773740f1995b02e6d7c0c1997858771 Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Thu, 28 Sep 2023 14:04:52 +0800
Subject: [PATCH] LLM: support chatglm-18b convert attention forward in
 benchmark scripts. (#9072)

* add chatglm-18b convert.

* fix if statement.

* fix
---
 python/llm/src/bigdl/llm/transformers/convert.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 07f929c3..86fd7914 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -153,7 +153,8 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass

-    if "chatglm2" in model.config._name_or_path:
+    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
+        # chatglm-18b or chatglm2-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
@@ -166,6 +167,7 @@ def optimize(model):
                         module.CoreAttention,
                         core_attn_forward_8eb45c)
     elif "chatglm" in model.config._name_or_path:
+        # chatglm-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
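
For reference, both branches touched by this patch hand the imported forward
functions to convert.py's convert_forward helper, which monkey-patches the
attention implementation at load time. The sketch below is a minimal,
hypothetical illustration of that pattern under the assumption that the helper
simply walks the torch module tree and rebinds `forward` on every instance of
the target class; it is not the library's exact code.

    import torch.nn as nn

    def convert_forward(m: nn.Module, target_m: type, new_forward) -> None:
        # Recursively visit submodules; when one is an instance of the
        # target class (e.g. module.CoreAttention), bind the replacement
        # function to that instance as its new `forward` method.
        for _, sub_m in m.named_children():
            if isinstance(sub_m, target_m):
                bound_method = new_forward.__get__(sub_m, sub_m.__class__)
                setattr(sub_m, "forward", bound_method)
            convert_forward(sub_m, target_m, new_forward)

In practice, optimize(model) runs when a matching checkpoint is loaded through
bigdl-llm's transformers-style API, e.g. (the model id shown is illustrative):

    from bigdl.llm.transformers import AutoModel

    model = AutoModel.from_pretrained("THUDM/chatglm2-6b",
                                      load_in_4bit=True,
                                      trust_remote_code=True)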