fix chatglm new model (#11639)

This commit is contained in:
Yishuo Wang 2024-07-23 13:44:56 +08:00 committed by GitHub
parent 7f80db95eb
commit 1b3b46e54d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1035,7 +1035,7 @@ def _optimize_post(model, lightweight_bmm=False):
if model.config.architectures is not None \
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
if hasattr(model.config, 'padded_vocab_size') and \
model.config.padded_vocab_size == 65024:
model.config.padded_vocab_size in [65024, 64896]:
# chatglm2-6b, chatglm2-6b-32k, chatglm3-6b, chatglm3-6b-32k, chatglm3-6b-128k
modeling_module_name = model.__class__.__module__
module = importlib.import_module(modeling_module_name)