This commit is contained in:
Yishuo Wang 2025-01-21 17:27:18 +08:00 committed by GitHub
parent 412bfd6644
commit 6789e5d92f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -139,19 +139,25 @@ def fix_key(key):
def get_autocast_dtype(x):
if x.device.type == "xpu":
if torch.xpu.is_autocast_xpu_enabled():
return torch.xpu.get_autocast_xpu_dtype()
else:
return None
elif x.device.type == "cpu":
if torch.is_autocast_cpu_enabled():
return torch.get_autocast_cpu_dtype()
if torch.__version__ >= '2.3':
if torch.is_autocast_enabled(x.device.type):
return torch.get_autocast_dtype(x.device.type)
else:
return None
else:
invalidInputError(False,
f"Device {x.device} is not supported.")
if x.device.type == "xpu":
if torch.xpu.is_autocast_xpu_enabled():
return torch.xpu.get_autocast_xpu_dtype()
else:
return None
elif x.device.type == "cpu":
if torch.is_autocast_cpu_enabled():
return torch.get_autocast_cpu_dtype()
else:
return None
else:
invalidInputError(False,
f"Device {x.device} is not supported.")
def get_xpu_device_name(device: torch.device):