LLM: fix device setting during saving optimized model (#10154)
This commit is contained in:
		
							parent
							
								
									1f6d5b9f30
								
							
						
					
					
						commit
						2bb96c775c
					
				
					 1 changed file with 3 additions and 0 deletions
				
			
		| 
						 | 
				
			
			@@ -59,6 +59,7 @@ def save_low_bit(self, *args, **kwargs):
 | 
			
		|||
        delattr(self.config, "quantization_config")
 | 
			
		||||
        delattr(self.config, "_pre_quantization_dtype")
 | 
			
		||||
 | 
			
		||||
    origin_device = self.device
 | 
			
		||||
    self.to('cpu')
 | 
			
		||||
 | 
			
		||||
    kwargs['safe_serialization'] = False
 | 
			
		||||
| 
						 | 
				
			
			@@ -85,6 +86,8 @@ def save_low_bit(self, *args, **kwargs):
 | 
			
		|||
    load_keys = {"all_checkpoint_keys": list(self.state_dict().keys())}
 | 
			
		||||
    with open(os.path.join(args[0], "load_keys.json"), "w") as json_file:
 | 
			
		||||
        json.dump(load_keys, json_file)
 | 
			
		||||
    if origin_device != 'cpu':
 | 
			
		||||
        self.to(origin_device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _BaseAutoModelClass:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue