Disable fused layer norm when using XMX to fix mpt UT (#9933)
This commit is contained in:
parent
1fc9dfa265
commit
7bbb98abb6
2 changed files with 14 additions and 1 deletions
|
|
@ -37,6 +37,7 @@ from typing import Optional, Tuple
|
|||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import functional as F
|
||||
from bigdl.llm.transformers.models.utils import use_fused_layer_norm
|
||||
from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
|
||||
|
||||
|
|
@ -63,7 +64,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
|||
|
||||
|
||||
def bloom_layer_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
if use_fused_layer_norm(hidden_states, self.training):
|
||||
import linear_q4_0
|
||||
result = linear_q4_0.fused_layer_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
|
|
|
|||
|
|
@ -283,3 +283,15 @@ def use_xmx(x: torch.Tensor, qtype: int):
|
|||
1 < x.size(0) <= 8
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def use_fused_layer_norm(x: torch.Tensor, training: bool):
|
||||
return (
|
||||
not training
|
||||
and not x.requires_grad
|
||||
and x.device.type == 'xpu'
|
||||
and (
|
||||
get_xpu_device_type(x) not in ["arc", "flex"]
|
||||
or x.reshape(-1, x.size(-1)).size(0) == 1
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue