diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 17f0c4e3..1a78a4fe 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -437,7 +437,10 @@ class LowBitLinear(nn.Linear):
         self.compute_dtype = None  # only for training
 
     def forward(self, x: torch.Tensor):
-        if self.training:
+        # Due to inconsistent training status in some models like Baichuan-7b-Chat,
+        # we should check both self.training and torch.is_inference_mode_enabled().
+        is_training = self.training and not torch.is_inference_mode_enabled()
+        if is_training:
             # below logic is only for training
             autocast_dtype = get_autocast_dtype(x)
             if self.compute_dtype is not None and x.device.type == "xpu":
@@ -476,7 +479,7 @@ class LowBitLinear(nn.Linear):
                 x_2d = x_2d.contiguous()
 
             input_seq_size = x_shape[1]
-            if self.training:
+            if is_training:
                 # training path
                 if x_2d.requires_grad:
                     result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
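A minimal sketch (not part of the patch) illustrating the rationale for the combined check: a module whose training flag was never reset still reports self.training == True inside torch.inference_mode(), so also checking torch.is_inference_mode_enabled() keeps such models (e.g. Baichuan-7b-Chat) on the inference path. The layer, shapes, and prints below are arbitrary and only for demonstration.

    import torch
    import torch.nn as nn

    linear = nn.Linear(4, 4)
    linear.train()  # stale training flag, as seen in some chat models

    with torch.inference_mode():
        x = torch.randn(1, 4)
        # same check as in the patch
        is_training = linear.training and not torch.is_inference_mode_enabled()
        print(linear.training)                    # True  -- flag alone is misleading
        print(torch.is_inference_mode_enabled())  # True  -- we are actually doing inference
        print(is_training)                        # False -- combined check selects the inference path
        _ = linear(x)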