diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index e0fb12a1..1ff2f805 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -368,8 +368,6 @@ class MatMulLowBitCPU(torch.autograd.Function): @staticmethod def forward(ctx, A, weight): - if torch.is_autocast_enabled(): - A = A.to(torch.get_autocast_dtype()) ctx.is_empty = False x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape, weight._shape[0] * weight._shape[1]) diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py index 5c271e88..e50f196f 100644 --- a/python/llm/src/bigdl/llm/transformers/utils.py +++ b/python/llm/src/bigdl/llm/transformers/utils.py @@ -142,7 +142,7 @@ def get_autocast_dtype(x): else: return None elif x.device.type == "cpu": - if torch.is_autocast_enabled(): + if torch.is_autocast_cpu_enabled(): return torch.get_autocast_cpu_dtype() else: return None