From 5e46e0e5afa0bd096af107a614f759f088e0aade Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Thu, 14 Dec 2023 09:58:32 +0800 Subject: [PATCH] fix baichuan2-7b 1st token performance regression on xpu (#9683) * fix baichuan2-7b 1st token performance regression * add comments * fix style --- python/llm/src/bigdl/llm/transformers/low_bit_linear.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 17f0c4e3..1a78a4fe 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -437,7 +437,10 @@ class LowBitLinear(nn.Linear): self.compute_dtype = None # only for training def forward(self, x: torch.Tensor): - if self.training: + # Due to inconsistent training status in some models like Baichuan-7b-Chat, + # we should check both self.training and torch.is_inference_mode_enabled(). + is_training = self.training and not torch.is_inference_mode_enabled() + if is_training: # below logic is only for training autocast_dtype = get_autocast_dtype(x) if self.compute_dtype is not None and x.device.type == "xpu": @@ -476,7 +479,7 @@ class LowBitLinear(nn.Linear): x_2d = x_2d.contiguous() input_seq_size = x_shape[1] - if self.training: + if is_training: # training path if x_2d.requires_grad: result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)