small fix (#12727)
This commit is contained in:
parent
412bfd6644
commit
6789e5d92f
1 changed files with 16 additions and 10 deletions
|
|
@ -139,6 +139,12 @@ def fix_key(key):
|
||||||
|
|
||||||
|
|
||||||
def get_autocast_dtype(x):
|
def get_autocast_dtype(x):
|
||||||
|
if torch.__version__ >= '2.3':
|
||||||
|
if torch.is_autocast_enabled(x.device.type):
|
||||||
|
return torch.get_autocast_dtype(x.device.type)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
if x.device.type == "xpu":
|
if x.device.type == "xpu":
|
||||||
if torch.xpu.is_autocast_xpu_enabled():
|
if torch.xpu.is_autocast_xpu_enabled():
|
||||||
return torch.xpu.get_autocast_xpu_dtype()
|
return torch.xpu.get_autocast_xpu_dtype()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue