From 4bda975a3e3f89c28f3ff11f0d0a69f79eff01a0 Mon Sep 17 00:00:00 2001
From: Zhao Changmin
Date: Thu, 21 Dec 2023 09:48:58 +0800
Subject: [PATCH] LLM: Align lowbit model config (#9735)

* align lowbit model config
---
 python/llm/src/bigdl/llm/transformers/convert.py |  3 ++-
 python/llm/src/bigdl/llm/transformers/model.py   | 13 +++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 95d98213..62b74ae3 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -406,7 +406,8 @@ def _optimize_post(model, lightweight_bmm=False):
                         nn.LayerNorm,
                         bloom_layer_norm_forward)
 
-    if model.config.architectures is not None and model.config.architectures[0] == "ChatGLMModel":
+    if model.config.architectures is not None \
+       and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
         if model.config.num_layers == 28 and hasattr(model.config, 'rope_ratio'):
             # chatglm2-6b-32k
             modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 68d2ef54..3b01d869 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -60,7 +60,20 @@ def save_low_bit(self, *args, **kwargs):
             delattr(self.config, "_pre_quantization_dtype")
 
     self.to('cpu')
+
+    architectures = getattr(self.config, "architectures", None)
+    model_type = getattr(self.config, "model_type", None)
     self.save_pretrained(*args, **kwargs)
+
+    if architectures:
+        self.config.update({"architectures": architectures})
+    if model_type:
+        self.config.update({"model_type": model_type})
+
+    self.config.save_pretrained(args[0])
+    if self.can_generate():
+        self.generation_config.save_pretrained(args[0])
+
     import json
     import os
     # We conveniently save all the keys of the model to have them on hand,
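
Below is a minimal usage sketch (not part of the patch) of the save/load round trip this change targets: save_low_bit() calls save_pretrained() internally, and the patch then writes the original "architectures" and "model_type" back into config.json and saves the generation config, so the low-bit checkpoint can be identified and reloaded. The model id and output directory are placeholders, and the from_pretrained/load_low_bit calls are assumed to follow the bigdl.llm transformers-style API.

    from bigdl.llm.transformers import AutoModelForCausalLM

    # Quantize the model to 4-bit while loading (placeholder model id).
    model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b",
                                                 load_in_4bit=True,
                                                 trust_remote_code=True)

    # save_low_bit() saves the quantized weights; with this patch the original
    # "architectures" and "model_type" fields are re-written into config.json
    # after save_pretrained(), and generation_config is saved alongside.
    model.save_low_bit("./chatglm2-6b-low-bit")

    # Reload the low-bit checkpoint; this relies on config.json still carrying
    # the original architectures/model_type values.
    model = AutoModelForCausalLM.load_low_bit("./chatglm2-6b-low-bit",
                                              trust_remote_code=True)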