Fix Baichuan2-7B first-token performance regression on XPU (#9683)
* fix baichuan2-7b 1st token performance regression * add comments * fix style
This commit is contained in:
parent
877229f3be
commit
5e46e0e5af
1 changed file with 5 additions and 2 deletions
|
|
@ -437,7 +437,10 @@ class LowBitLinear(nn.Linear):
|
||||||
self.compute_dtype = None # only for training
|
self.compute_dtype = None # only for training
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
if self.training:
|
# Due to inconsistent training status in some models like Baichuan-7b-Chat,
|
||||||
|
# we should check both self.training and torch.is_inference_mode_enabled().
|
||||||
|
is_training = self.training and not torch.is_inference_mode_enabled()
|
||||||
|
if is_training:
|
||||||
# below logic is only for training
|
# below logic is only for training
|
||||||
autocast_dtype = get_autocast_dtype(x)
|
autocast_dtype = get_autocast_dtype(x)
|
||||||
if self.compute_dtype is not None and x.device.type == "xpu":
|
if self.compute_dtype is not None and x.device.type == "xpu":
|
||||||
|
|
@ -476,7 +479,7 @@ class LowBitLinear(nn.Linear):
|
||||||
x_2d = x_2d.contiguous()
|
x_2d = x_2d.contiguous()
|
||||||
|
|
||||||
input_seq_size = x_shape[1]
|
input_seq_size = x_shape[1]
|
||||||
if self.training:
|
if is_training:
|
||||||
# training path
|
# training path
|
||||||
if x_2d.requires_grad:
|
if x_2d.requires_grad:
|
||||||
result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
|
result = MatMulLowBit.apply(x_2d, self.weight, input_seq_size)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue