small fix (#12727)
This commit is contained in:
		
							parent
							
								
									412bfd6644
								
							
						
					
					
						commit
						6789e5d92f
					
				
					 1 changed files with 16 additions and 10 deletions
				
			
		| 
						 | 
					@ -139,19 +139,25 @@ def fix_key(key):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_autocast_dtype(x):
 | 
					def get_autocast_dtype(x):
 | 
				
			||||||
    if x.device.type == "xpu":
 | 
					    if torch.__version__ >= '2.3':
 | 
				
			||||||
        if torch.xpu.is_autocast_xpu_enabled():
 | 
					        if torch.is_autocast_enabled(x.device.type):
 | 
				
			||||||
            return torch.xpu.get_autocast_xpu_dtype()
 | 
					            return torch.get_autocast_dtype(x.device.type)
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return None
 | 
					 | 
				
			||||||
    elif x.device.type == "cpu":
 | 
					 | 
				
			||||||
        if torch.is_autocast_cpu_enabled():
 | 
					 | 
				
			||||||
            return torch.get_autocast_cpu_dtype()
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        invalidInputError(False,
 | 
					        if x.device.type == "xpu":
 | 
				
			||||||
                          f"Device {x.device} is not supported.")
 | 
					            if torch.xpu.is_autocast_xpu_enabled():
 | 
				
			||||||
 | 
					                return torch.xpu.get_autocast_xpu_dtype()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					        elif x.device.type == "cpu":
 | 
				
			||||||
 | 
					            if torch.is_autocast_cpu_enabled():
 | 
				
			||||||
 | 
					                return torch.get_autocast_cpu_dtype()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            invalidInputError(False,
 | 
				
			||||||
 | 
					                              f"Device {x.device} is not supported.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_xpu_device_name(device: torch.device):
 | 
					def get_xpu_device_name(device: torch.device):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue