LLM: fix device setting when saving optimized model (#10154)

This commit is contained in:
binbin Deng 2024-02-20 09:52:59 +08:00 committed by GitHub
parent 1f6d5b9f30
commit 2bb96c775c

View file

@ -59,6 +59,7 @@ def save_low_bit(self, *args, **kwargs):
delattr(self.config, "quantization_config")
delattr(self.config, "_pre_quantization_dtype")
origin_device = self.device
self.to('cpu')
kwargs['safe_serialization'] = False
@ -85,6 +86,8 @@ def save_low_bit(self, *args, **kwargs):
load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
with open(os.path.join(args[0], "load_keys.json"), "w") as json_file:
json.dump(load_keys, json_file)
if origin_device != 'cpu':
self.to(origin_device)
class _BaseAutoModelClass: