From 6789e5d92f4c8cbd6f5734b512770564e3f15c29 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 21 Jan 2025 17:27:18 +0800 Subject: [PATCH] small fix (#12727) --- python/llm/src/ipex_llm/transformers/utils.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py index 329a3a4b..5bd24667 100644 --- a/python/llm/src/ipex_llm/transformers/utils.py +++ b/python/llm/src/ipex_llm/transformers/utils.py @@ -139,19 +139,25 @@ def fix_key(key): def get_autocast_dtype(x): - if x.device.type == "xpu": - if torch.xpu.is_autocast_xpu_enabled(): - return torch.xpu.get_autocast_xpu_dtype() - else: - return None - elif x.device.type == "cpu": - if torch.is_autocast_cpu_enabled(): - return torch.get_autocast_cpu_dtype() + if torch.__version__ >= '2.3': + if torch.is_autocast_enabled(x.device.type): + return torch.get_autocast_dtype(x.device.type) else: return None else: - invalidInputError(False, - f"Device {x.device} is not supported.") + if x.device.type == "xpu": + if torch.xpu.is_autocast_xpu_enabled(): + return torch.xpu.get_autocast_xpu_dtype() + else: + return None + elif x.device.type == "cpu": + if torch.is_autocast_cpu_enabled(): + return torch.get_autocast_cpu_dtype() + else: + return None + else: + invalidInputError(False, + f"Device {x.device} is not supported.") def get_xpu_device_name(device: torch.device):