LLM: fix device setting during saving optimized model (#10154)
This commit is contained in:
parent
1f6d5b9f30
commit
2bb96c775c
1 changed files with 3 additions and 0 deletions
|
|
@ -59,6 +59,7 @@ def save_low_bit(self, *args, **kwargs):
|
||||||
delattr(self.config, "quantization_config")
|
delattr(self.config, "quantization_config")
|
||||||
delattr(self.config, "_pre_quantization_dtype")
|
delattr(self.config, "_pre_quantization_dtype")
|
||||||
|
|
||||||
|
origin_device = self.device
|
||||||
self.to('cpu')
|
self.to('cpu')
|
||||||
|
|
||||||
kwargs['safe_serialization'] = False
|
kwargs['safe_serialization'] = False
|
||||||
|
|
@ -85,6 +86,8 @@ def save_low_bit(self, *args, **kwargs):
|
||||||
load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
|
load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
|
||||||
with open(os.path.join(args[0], "load_keys.json"), "w") as json_file:
|
with open(os.path.join(args[0], "load_keys.json"), "w") as json_file:
|
||||||
json.dump(load_keys, json_file)
|
json.dump(load_keys, json_file)
|
||||||
|
if origin_device != 'cpu':
|
||||||
|
self.to(origin_device)
|
||||||
|
|
||||||
|
|
||||||
class _BaseAutoModelClass:
|
class _BaseAutoModelClass:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue