parent
e1e921f425
commit
4bda975a3e
2 changed files with 15 additions and 1 deletions
|
|
@ -406,7 +406,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
nn.LayerNorm,
|
||||
bloom_layer_norm_forward)
|
||||
|
||||
if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
||||
if model.config.architectures is not None \
|
||||
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
|
||||
if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
|
||||
# chatglm2-6b-32k
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
|
|||
|
|
@ -60,7 +60,20 @@ def save_low_bit(self, *args, **kwargs):
|
|||
delattr(self.config, "_pre_quantization_dtype")
|
||||
|
||||
self.to('cpu')
|
||||
|
||||
architectures = getattr(self.config, "architectures", None)
|
||||
model_type = getattr(self.config, "model_type", None)
|
||||
self.save_pretrained(*args, **kwargs)
|
||||
|
||||
if architectures:
|
||||
self.config.update({"architectures": architectures})
|
||||
if model_type:
|
||||
self.config.update({"model_type": model_type})
|
||||
|
||||
self.config.save_pretrained(args[0])
|
||||
if self.can_generate():
|
||||
self.generation_config.save_pretrained(args[0])
|
||||
|
||||
import json
|
||||
import os
|
||||
# We conveniently save all the keys of the model to have them on hand,
|
||||
|
|
|
|||
Loading…
Reference in a new issue