LLM: Align lowbit model config (#9735)

* align lowbit model config
Zhao Changmin 2023-12-21 09:48:58 +08:00 committed by GitHub
parent e1e921f425
commit 4bda975a3e
2 changed files with 15 additions and 1 deletion
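
In practice, the change addresses two gaps: save_low_bit() could emit a config.json whose "architectures"/"model_type" no longer matched the original checkpoint, and the ChatGLM optimization path only recognized configs whose first architecture was "ChatGLMModel". Below is a hedged usage sketch of the round trip this commit fixes; the model id and output path are illustrative, and the imports assume the bigdl-llm Python API this repository shipped at the time:

    from bigdl.llm.transformers import AutoModelForCausalLM

    # Load a ChatGLM-family checkpoint in 4-bit (model id is an example).
    model = AutoModelForCausalLM.from_pretrained(
        "THUDM/chatglm2-6b", load_in_4bit=True, trust_remote_code=True)

    # Persist the quantized weights plus config; after this commit the saved
    # config.json keeps the checkpoint's original architectures/model_type.
    model.save_low_bit("./chatglm2-6b-low-bit")

    # Reload; the aligned config lets the loader route the model through the
    # same ChatGLM-specific optimizations as the original checkpoint.
    model = AutoModelForCausalLM.load_low_bit(
        "./chatglm2-6b-low-bit", trust_remote_code=True)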


@@ -406,7 +406,8 @@ def _optimize_post(model, lightweight_bmm=False):
                         nn.LayerNorm,
                         bloom_layer_norm_forward)
 
-    if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
         if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
             # chatglm2-6b-32k
             modeling_module_name = model.__class__.__module__
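
The hunk above broadens ChatGLM detection in _optimize_post: some ChatGLM checkpoints list "ChatGLMForConditionalGeneration" rather than "ChatGLMModel" as their first architecture, so the strict equality becomes a membership test. A minimal sketch of the resulting predicate; is_chatglm is a hypothetical helper name, while the architecture strings come straight from the diff:

    _CHATGLM_ARCHITECTURES = ["ChatGLMModel", "ChatGLMForConditionalGeneration"]

    def is_chatglm(config) -> bool:
        # Mirrors the broadened check: guard against a missing attribute,
        # then match on the first declared architecture.
        archs = getattr(config, "architectures", None)
        return archs is not None and archs[0] in _CHATGLM_ARCHITECTURES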


@@ -60,7 +60,20 @@ def save_low_bit(self, *args, **kwargs):
         delattr(self.config, "_pre_quantization_dtype")
     self.to('cpu')
 
+    architectures = getattr(self.config, "architectures", None)
+    model_type = getattr(self.config, "model_type", None)
+
     self.save_pretrained(*args, **kwargs)
+
+    if architectures:
+        self.config.update({"architectures": architectures})
+    if model_type:
+        self.config.update({"model_type": model_type})
+
+    self.config.save_pretrained(args[0])
+    if self.can_generate():
+        self.generation_config.save_pretrained(args[0])
+
     import json
     import os
     # We conveniently save all the keys of the model to have them on hand,
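
The hunk above hardens save_low_bit(): Hugging Face's save_pretrained() serializes the config of the in-memory, already-patched model, which can drop or rewrite the checkpoint's original "architectures" and "model_type". The fix snapshots both fields first, restores them after saving, re-writes config.json into args[0] (the output directory), and also persists the generation config for generative models. A hedged verification sketch follows; save_dir is an illustrative path pointing at a directory produced by save_low_bit():

    import json
    import os

    save_dir = "./chatglm2-6b-low-bit"  # assumption: output of save_low_bit()
    with open(os.path.join(save_dir, "config.json")) as f:
        cfg = json.load(f)

    # Before this commit these fields could reflect the patched in-memory
    # class; after it they match the original checkpoint's config.
    print(cfg.get("architectures"), cfg.get("model_type"))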