parent
e1e921f425
commit
4bda975a3e
2 changed files with 15 additions and 1 deletions
|
|
@ -406,7 +406,8 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
nn.LayerNorm,
|
nn.LayerNorm,
|
||||||
bloom_layer_norm_forward)
|
bloom_layer_norm_forward)
|
||||||
|
|
||||||
if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
|
if model.config.architectures is not None \
|
||||||
|
and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
|
||||||
if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
|
if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
|
||||||
# chatglm2-6b-32k
|
# chatglm2-6b-32k
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,20 @@ def save_low_bit(self, *args, **kwargs):
|
||||||
delattr(self.config, "_pre_quantization_dtype")
|
delattr(self.config, "_pre_quantization_dtype")
|
||||||
|
|
||||||
self.to('cpu')
|
self.to('cpu')
|
||||||
|
|
||||||
|
architectures = getattr(self.config, "architectures", None)
|
||||||
|
model_type = getattr(self.config, "model_type", None)
|
||||||
self.save_pretrained(*args, **kwargs)
|
self.save_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
|
if architectures:
|
||||||
|
self.config.update({"architectures": architectures})
|
||||||
|
if model_type:
|
||||||
|
self.config.update({"model_type": model_type})
|
||||||
|
|
||||||
|
self.config.save_pretrained(args[0])
|
||||||
|
if self.can_generate():
|
||||||
|
self.generation_config.save_pretrained(args[0])
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
# We conveniently save all the keys of the model to have them on hand,
|
# We conveniently save all the keys of the model to have them on hand,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue