diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 6715dbce..e0fb12a1 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -336,8 +336,6 @@ class MatMulLowBit(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(ctx, A, weight, input_seq_size):
-        if torch.xpu.is_autocast_xpu_enabled():
-            A = A.to(torch.xpu.get_autocast_xpu_dtype())
         ctx.is_empty = False
         import linear_q4_0
         result = linear_q4_0.forward_new(A, weight.data, weight.qtype, input_seq_size)
@@ -448,8 +446,6 @@ class LowBitLinear(nn.Linear):
                                                  input_seq_size)
                 result = result.to(x.dtype)
             else:
-                if torch.xpu.is_autocast_xpu_enabled():
-                    x_2d = x_2d.to(torch.xpu.get_autocast_xpu_dtype())
                 result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                                  input_seq_size)
             new_shape = x_shape[:-1] + (self.out_len,)
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index cb6b5d36..35b1f0ef 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -52,6 +52,7 @@ import torch
 from bigdl.llm.transformers.low_bit_linear import LowBitLinear
 from peft.tuners.lora import LoraLayer
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_autocast_dtype
 import functools
@@ -85,13 +86,16 @@ class LoraLowBitLinear(LowBitLinear, LoraLayer):
         self.active_adapter = adapter_name
 
     def forward(self, x: torch.Tensor):
+        autocast_dtype = get_autocast_dtype(x)
+        if autocast_dtype is not None:
+            x = x.to(autocast_dtype)
         result = super().forward(x)
 
         if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
             return result
         elif self.r[self.active_adapter] > 0:
             result = result.clone()
-            if not torch.is_autocast_enabled():
+            if autocast_dtype is None:
                 expected_dtype = result.dtype
                 x = x.to(self.lora_A[self.active_adapter].weight.dtype)
                 output = (
diff --git a/python/llm/src/bigdl/llm/transformers/utils.py b/python/llm/src/bigdl/llm/transformers/utils.py
index 499765e1..5c271e88 100644
--- a/python/llm/src/bigdl/llm/transformers/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/utils.py
@@ -133,3 +133,19 @@ def fix_key(key):
     if "gamma" in key:
         return key.replace("gamma", "weight")
     return key
+
+
+def get_autocast_dtype(x):
+    if x.device.type == "xpu":
+        if torch.xpu.is_autocast_xpu_enabled():
+            return torch.xpu.get_autocast_xpu_dtype()
+        else:
+            return None
+    elif x.device.type == "cpu":
+        if torch.is_autocast_enabled():
+            return torch.get_autocast_cpu_dtype()
+        else:
+            return None
+    else:
+        invalidInputError(False,
+                          f"Device {x.device} is not supported.")
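
Illustrative usage (not part of the patch): a minimal sketch of how the new get_autocast_dtype helper is intended to be consumed, following the pattern the patch introduces in LoraLowBitLinear.forward; it assumes a CPU tensor inside a standard torch.autocast context, and dummy_low_bit_forward is a hypothetical stand-in for the real low-bit kernels.

import torch
from bigdl.llm.transformers.utils import get_autocast_dtype

def dummy_low_bit_forward(x):
    # hypothetical stand-in for linear_q4_0.forward_new / LowBitLinear.forward
    return x

x = torch.randn(2, 4)  # fp32 activations on CPU
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Query the ambient autocast dtype once on the Python side, instead of
    # re-checking it inside the low-bit kernels (the checks the patch removes).
    autocast_dtype = get_autocast_dtype(x)
    if autocast_dtype is not None:
        x = x.to(autocast_dtype)  # cast the activation up front, exactly once
    result = dummy_low_bit_forward(x)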