LLM: fix device setting when saving the optimized model (#10154)
This commit is contained in:
parent
1f6d5b9f30
commit
2bb96c775c
1 changed file with 3 additions and 0 deletions
|
|
@ -59,6 +59,7 @@ def save_low_bit(self, *args, **kwargs):
|
|||
delattr(self.config, "quantization_config")
|
||||
delattr(self.config, "_pre_quantization_dtype")
|
||||
|
||||
origin_device = self.device
|
||||
self.to('cpu')
|
||||
|
||||
kwargs['safe_serialization'] = False
|
||||
|
|
@ -85,6 +86,8 @@ def save_low_bit(self, *args, **kwargs):
|
|||
load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
|
||||
with open(os.path.join(args[0], "load_keys.json"), "w") as json_file:
|
||||
json.dump(load_keys, json_file)
|
||||
if origin_device != 'cpu':
|
||||
self.to(origin_device)
|
||||
|
||||
|
||||
class _BaseAutoModelClass:
|
||||
|
|
|
|||
Loading…
Reference in a new issue