small fix (#12727)
This commit is contained in:
		
							parent
							
								
									412bfd6644
								
							
						
					
					
						commit
						6789e5d92f
					
				
					 1 changed files with 16 additions and 10 deletions
				
			
		| 
						 | 
				
			
			@ -139,19 +139,25 @@ def fix_key(key):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
def get_autocast_dtype(x):
 | 
			
		||||
    if x.device.type == "xpu":
 | 
			
		||||
        if torch.xpu.is_autocast_xpu_enabled():
 | 
			
		||||
            return torch.xpu.get_autocast_xpu_dtype()
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
    elif x.device.type == "cpu":
 | 
			
		||||
        if torch.is_autocast_cpu_enabled():
 | 
			
		||||
            return torch.get_autocast_cpu_dtype()
 | 
			
		||||
    if torch.__version__ >= '2.3':
 | 
			
		||||
        if torch.is_autocast_enabled(x.device.type):
 | 
			
		||||
            return torch.get_autocast_dtype(x.device.type)
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
    else:
 | 
			
		||||
        invalidInputError(False,
 | 
			
		||||
                          f"Device {x.device} is not supported.")
 | 
			
		||||
        if x.device.type == "xpu":
 | 
			
		||||
            if torch.xpu.is_autocast_xpu_enabled():
 | 
			
		||||
                return torch.xpu.get_autocast_xpu_dtype()
 | 
			
		||||
            else:
 | 
			
		||||
                return None
 | 
			
		||||
        elif x.device.type == "cpu":
 | 
			
		||||
            if torch.is_autocast_cpu_enabled():
 | 
			
		||||
                return torch.get_autocast_cpu_dtype()
 | 
			
		||||
            else:
 | 
			
		||||
                return None
 | 
			
		||||
        else:
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              f"Device {x.device} is not supported.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_xpu_device_name(device: torch.device):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue