Fix Baichuan2-7B first-token performance regression on XPU (#9683)
* Fix Baichuan2-7B first-token performance regression
* Add comments
* Fix style
This commit is contained in:
parent
877229f3be
commit
5e46e0e5af
1 changed files with 5 additions and 2 deletions
|
|
@@ -437,7 +437,10 @@ class LowBitLinear(nn.Linear):
|
|||
self.compute_dtype = None # only for training
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.training:
|
||||
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
|
||||
# we should check both self.training and torch.is_inference_mode_enabled().
|
||||
is_training = self.training and not torch.is_inference_mode_enabled()
|
||||
if is_training:
|
||||
# below logic is only for training
|
||||
autocast_dtype = get_autocast_dtype(x)
|
||||
if self.compute_dtype is not None and x.device.type == "xpu":
|
||||
|
|
@@ -476,7 +479,7 @@ class LowBitLinear(nn.Linear):
|
|||
x_2d = x_2d.contiguous()
|
||||
|
||||
input_seq_size = x_shape[1]
|
||||
if self.training:
|
||||
if is_training:
|
||||
# training path
|
||||
if x_2d.requires_grad:
|
||||
result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
|
||||
|
|
|
|||
Loading…
Reference in a new issue