fix baichuan2-7b 1st token performance regression on xpu (#9683)

* fix baichuan2-7b 1st token performance regression

* add comments

* fix style
This commit is contained in:
Xin Qiu 2023-12-14 09:58:32 +08:00 committed by GitHub
parent 877229f3be
commit 5e46e0e5af

View file

@@ -437,7 +437,10 @@ class LowBitLinear(nn.Linear):
         self.compute_dtype = None  # only for training

     def forward(self, x: torch.Tensor):
-        if self.training:
+        # Due to inconsistent training status in some models like Baichuan-7b-Chat,
+        # we should check both self.training and torch.is_inference_mode_enabled().
+        is_training = self.training and not torch.is_inference_mode_enabled()
+        if is_training:
             # below logic is only for training
             autocast_dtype = get_autocast_dtype(x)
             if self.compute_dtype is not None and x.device.type == "xpu":
@@ -476,7 +479,7 @@ class LowBitLinear(nn.Linear):
             x_2d = x_2d.contiguous()
         input_seq_size = x_shape[1]
-        if self.training:
+        if is_training:
             # training path
             if x_2d.requires_grad:
                 result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)