This commit is contained in:
Yishuo Wang 2025-01-21 17:27:18 +08:00 committed by GitHub
parent 412bfd6644
commit 6789e5d92f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -139,6 +139,12 @@ def fix_key(key):
def get_autocast_dtype(x):
if torch.__version__ >= '2.3':
if torch.is_autocast_enabled(x.device.type):
return torch.get_autocast_dtype(x.device.type)
else:
return None
else:
if x.device.type == "xpu":
if torch.xpu.is_autocast_xpu_enabled():
return torch.xpu.get_autocast_xpu_dtype()