small fix (#12727)
This commit is contained in:
parent
412bfd6644
commit
6789e5d92f
1 changed files with 16 additions and 10 deletions
|
|
@ -139,19 +139,25 @@ def fix_key(key):
|
||||||
|
|
||||||
|
|
||||||
def get_autocast_dtype(x):
|
def get_autocast_dtype(x):
|
||||||
if x.device.type == "xpu":
|
if torch.__version__ >= '2.3':
|
||||||
if torch.xpu.is_autocast_xpu_enabled():
|
if torch.is_autocast_enabled(x.device.type):
|
||||||
return torch.xpu.get_autocast_xpu_dtype()
|
return torch.get_autocast_dtype(x.device.type)
|
||||||
else:
|
|
||||||
return None
|
|
||||||
elif x.device.type == "cpu":
|
|
||||||
if torch.is_autocast_cpu_enabled():
|
|
||||||
return torch.get_autocast_cpu_dtype()
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
invalidInputError(False,
|
if x.device.type == "xpu":
|
||||||
f"Device {x.device} is not supported.")
|
if torch.xpu.is_autocast_xpu_enabled():
|
||||||
|
return torch.xpu.get_autocast_xpu_dtype()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
elif x.device.type == "cpu":
|
||||||
|
if torch.is_autocast_cpu_enabled():
|
||||||
|
return torch.get_autocast_cpu_dtype()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
invalidInputError(False,
|
||||||
|
f"Device {x.device} is not supported.")
|
||||||
|
|
||||||
|
|
||||||
def get_xpu_device_name(device: torch.device):
|
def get_xpu_device_name(device: torch.device):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue