diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 1f95ccbe..9ac192ce 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -59,6 +59,7 @@ def save_low_bit(self, *args, **kwargs): delattr(self.config, "quantization_config") delattr(self.config, "_pre_quantization_dtype") + origin_device = self.device self.to('cpu') kwargs['safe_serialization'] = False @@ -85,6 +86,8 @@ def save_low_bit(self, *args, **kwargs): load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())} with open(os.path.join(args[0], "load_keys.json"), "w") as json_file: json.dump(load_keys, json_file) + if origin_device != 'cpu': + self.to(origin_device) class _BaseAutoModelClass: