diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 07f929c3..86fd7914 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -153,7 +153,8 @@ def optimize(model):
         # todo implement 4.28.0 ~ 4.30.2
         pass
 
-    if "chatglm2" in model.config._name_or_path:
+    if "chatglm-18b" in model.config._name_or_path or "chatglm2" in model.config._name_or_path:
+        # chatglm-18b or chatglm2-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.chatglm2 import chatglm2_attention_forward_8eb45c
@@ -166,6 +167,7 @@ def optimize(model):
                         module.CoreAttention,
                         core_attn_forward_8eb45c)
     elif "chatglm" in model.config._name_or_path:
+        # chatglm-6b
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.chatglm import chatglm_attention_forward
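
For context (not part of the patch): each branch above dispatches on the model's `_name_or_path` and then swaps in an optimized attention forward from bigdl.llm.transformers.models. Below is a minimal, self-contained sketch of that pattern, name-based dispatch plus rebinding `forward` on an attention object. The `convert_forward` signature and the toy classes here are illustrative assumptions, not the exact helper defined in convert.py.

import types


class CoreAttention:
    # Toy stand-in for the modeling class that convert.py patches.
    def forward(self, x):
        return f"original({x})"


def core_attn_forward_8eb45c(self, x):
    # Stand-in for the optimized forward imported from
    # bigdl.llm.transformers.models.chatglm2 in the diff above.
    return f"optimized({x})"


def convert_forward(instance, new_forward):
    # Hypothetical helper (assumed signature): rebind `forward` so later
    # calls on this instance dispatch to the optimized implementation.
    instance.forward = types.MethodType(new_forward, instance)


attn = CoreAttention()
convert_forward(attn, core_attn_forward_8eb45c)
print(attn.forward("q, k, v"))  # -> optimized(q, k, v)

Because the patching happens per-instance at optimize() time, the original modeling classes stay untouched on disk; only the loaded model objects pick up the optimized kernels.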