From 2bb96c775cbf9f55c6d1c1b29bfe079f66fd3f51 Mon Sep 17 00:00:00 2001 From: binbin Deng <108676127+plusbang@users.noreply.github.com> Date: Tue, 20 Feb 2024 09:52:59 +0800 Subject: [PATCH] LLM: fix device setting during saving optimized model (#10154) --- python/llm/src/bigdl/llm/transformers/model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py index 1f95ccbe..9ac192ce 100644 --- a/python/llm/src/bigdl/llm/transformers/model.py +++ b/python/llm/src/bigdl/llm/transformers/model.py @@ -59,6 +59,7 @@ def save_low_bit(self, *args, **kwargs): delattr(self.config, "quantization_config") delattr(self.config, "_pre_quantization_dtype") + origin_device = self.device self.to('cpu') kwargs['safe_serialization'] = False @@ -85,6 +86,8 @@ def save_low_bit(self, *args, **kwargs): load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())} with open(os.path.join(args[0], "load_keys.json"), "w") as json_file: json.dump(load_keys, json_file) + if origin_device != 'cpu': + self.to(origin_device) class _BaseAutoModelClass: