From 40cead6b5b3706c219e06df5278baa0857cabf74 Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Thu, 9 Nov 2023 14:34:01 +0800
Subject: [PATCH] LLM: Fix CPU qlora dtype convert issue (#9394)

---
 python/llm/src/bigdl/llm/transformers/low_bit_linear.py | 2 --
 python/llm/src/bigdl/llm/transformers/utils.py          | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index e0fb12a1..1ff2f805 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -368,8 +368,6 @@ class MatMulLowBitCPU(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, A, weight):
-        if torch.is_autocast_enabled():
-            A = A.to(torch.get_autocast_dtype())
         ctx.is_empty = False
         x0_fp32 = ggml_int4_convert_fp32(weight.data, weight._shape,
                                          weight._shape[0] * weight._shape[1])
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index 5c271e88..e50f196f 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -142,7 +142,7 @@ def get_autocast_dtype(x):
         else:
             return None
     elif x.device.type == "cpu":
-        if torch.is_autocast_enabled():
+        if torch.is_autocast_cpu_enabled():
             return torch.get_autocast_cpu_dtype()
         else:
             return None
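
Why the change works (a reading of the diff, not text from the original patch):
in the PyTorch releases current when this landed (around 2.0/2.1),
torch.is_autocast_enabled() reports the CUDA autocast state, so it stays False
inside a CPU autocast region; the CPU state has its own queries,
torch.is_autocast_cpu_enabled() and torch.get_autocast_cpu_dtype(). The branch
deleted from MatMulLowBitCPU.forward used the generic check (always False on
CPU) together with an argument-free torch.get_autocast_dtype(), which those
releases did not expose, so the intended dtype conversion never happened on
CPU. Below is a minimal standalone sketch of the behavior difference, assuming
PyTorch >= 1.10 with CPU autocast support; it is illustration, not code from
this repository:

    import torch

    # Inside a CPU autocast region, the generic query still reports the
    # CUDA autocast state, which is why the CPU-specific check is needed.
    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
        print(torch.is_autocast_enabled())      # False: CUDA autocast state
        print(torch.is_autocast_cpu_enabled())  # True: CPU autocast state
        print(torch.get_autocast_cpu_dtype())   # torch.bfloat16

With the utils.py fix, get_autocast_dtype(x) returns torch.bfloat16 (or
whatever dtype the CPU autocast context set) for CPU tensors instead of None,
which is what the qlora path needs for a correct dtype conversion.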