diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 6fb465c6..88b5e94d 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -50,6 +50,7 @@ from torch import Tensor, device, dtype, nn
 from operator import mul
 from functools import reduce
 from bigdl.llm.transformers.xpu_customize_fwd import custom_fwd, custom_bwd
+from bigdl.llm.transformers.utils import get_autocast_dtype
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -433,8 +434,17 @@ class LowBitLinear(nn.Linear):
         self.qtype = qtype
         self.conver_to_half = conver_to_half
         self.mp_group = mp_group
+        self.compute_dtype = None  # only for training
 
     def forward(self, x: torch.Tensor):
+        if self.training:
+            # below logic is only for training
+            autocast_dtype = get_autocast_dtype(x)
+            if self.compute_dtype is not None and x.device.type == "xpu":
+                x = x.to(self.compute_dtype)  # solve GC issue for unlora module
+            elif autocast_dtype is not None:
+                x = x.to(autocast_dtype)
+
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
 
@@ -457,9 +467,16 @@ class LowBitLinear(nn.Linear):
                 x_2d = x_2d.contiguous()
 
             input_seq_size = x_shape[1]
-            if self.training and x_2d.requires_grad:
-                result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+            if self.training:
+                # training path
+                if x_2d.requires_grad:
+                    result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
+                else:
+                    result = linear_q4_0.forward_new(x_2d, self.weight.data,
+                                                     self.weight.qtype,
+                                                     input_seq_size)
             else:
+                # inference path
                 # current workaround to reduce first token latency of fp32 input
                 # sometimes fp16 cause nan and training instability
                 # disable the conversion when training
diff --git a/python/llm/src/bigdl/llm/transformers/qlora.py b/python/llm/src/bigdl/llm/transformers/qlora.py
index 53916ed9..f507755f 100644
--- a/python/llm/src/bigdl/llm/transformers/qlora.py
+++ b/python/llm/src/bigdl/llm/transformers/qlora.py
@@ -391,6 +391,8 @@ TrainingArguments._setup_devices = _setup_devices
 
 def cast_lora_weight(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        if isinstance(module, LowBitLinear):
+            module.compute_dtype = dtype
         if isinstance(module, LoraLayer):
             module = module.to(dtype)
             if 'norm' in name:
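
For reviewers, a minimal standalone sketch of the dtype resolution this patch adds to `LowBitLinear.forward` during training (illustrative only, not part of the patch): `compute_dtype`, which `cast_lora_weight` now sets on every `LowBitLinear`, takes precedence on XPU devices so non-LoRA low-bit layers match the dtype of the casted LoRA layers; otherwise the active autocast dtype is used, and the input is left untouched when neither applies. The helper name below is hypothetical; in the actual change this logic lives inline in `forward` and uses `get_autocast_dtype` from `bigdl.llm.transformers.utils`.

```python
# Illustrative sketch only -- mirrors the new training-time cast in LowBitLinear.forward.
# `resolve_training_input` is a hypothetical helper, not part of the patch.
import torch


def resolve_training_input(x: torch.Tensor, compute_dtype, autocast_dtype):
    # compute_dtype (set by cast_lora_weight on LowBitLinear modules) wins on XPU.
    if compute_dtype is not None and x.device.type == "xpu":
        return x.to(compute_dtype)
    # Otherwise follow the autocast dtype when autocast is active.
    if autocast_dtype is not None:
        return x.to(autocast_dtype)
    # No explicit compute dtype and no autocast: leave the input as-is.
    return x


# Example: an fp32 activation with compute_dtype=torch.bfloat16 on an XPU tensor
# would be cast to bf16 before the low-bit matmul; on CPU it stays fp32 unless
# autocast supplies a dtype.
```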