small fix (#12727)
This commit is contained in:
parent
412bfd6644
commit
6789e5d92f
1 changed files with 16 additions and 10 deletions
|
|
@ -139,19 +139,25 @@ def fix_key(key):
|
|||
|
||||
|
||||
def get_autocast_dtype(x):
|
||||
if x.device.type == "xpu":
|
||||
if torch.xpu.is_autocast_xpu_enabled():
|
||||
return torch.xpu.get_autocast_xpu_dtype()
|
||||
else:
|
||||
return None
|
||||
elif x.device.type == "cpu":
|
||||
if torch.is_autocast_cpu_enabled():
|
||||
return torch.get_autocast_cpu_dtype()
|
||||
if torch.__version__ >= '2.3':
|
||||
if torch.is_autocast_enabled(x.device.type):
|
||||
return torch.get_autocast_dtype(x.device.type)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
invalidInputError(False,
|
||||
f"Device {x.device} is not supported.")
|
||||
if x.device.type == "xpu":
|
||||
if torch.xpu.is_autocast_xpu_enabled():
|
||||
return torch.xpu.get_autocast_xpu_dtype()
|
||||
else:
|
||||
return None
|
||||
elif x.device.type == "cpu":
|
||||
if torch.is_autocast_cpu_enabled():
|
||||
return torch.get_autocast_cpu_dtype()
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
invalidInputError(False,
|
||||
f"Device {x.device} is not supported.")
|
||||
|
||||
|
||||
def get_xpu_device_name(device: torch.device):
|
||||
|
|
|
|||
Loading…
Reference in a new issue