Disable fused layer norm when using XMX to fix mpt UT (#9933)
parent 1fc9dfa265 · commit 7bbb98abb6

2 changed files with 14 additions and 1 deletion
@@ -37,6 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
+from bigdl.llm.transformers.models.utils import use_fused_layer_norm
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 
 
@@ -63,7 +64,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
 
 
 def bloom_layer_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+    if use_fused_layer_norm(hidden_states, self.training):
         import linear_q4_0
         result = linear_q4_0.fused_layer_norm(hidden_states,
                                               [self.weight.size(0)],
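For context on how a patched forward like this is consumed, below is a minimal, hypothetical sketch of binding such a forward onto every LayerNorm in a model. The patch_layer_norms helper is an illustration for this write-up, not BigDL-LLM's actual conversion code.

import types
import torch.nn as nn

def patch_layer_norms(model: nn.Module, new_forward) -> nn.Module:
    # Rebind forward on each nn.LayerNorm so inference flows through the
    # use_fused_layer_norm gate introduced in the hunk above.
    for module in model.modules():
        if isinstance(module, nn.LayerNorm):
            module.forward = types.MethodType(new_forward, module)
    return model

# Usage (assuming bloom_layer_norm_forward is importable as in the import hunk above):
# model = patch_layer_norms(model, bloom_layer_norm_forward)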
@@ -283,3 +283,15 @@ def use_xmx(x: torch.Tensor, qtype: int):
             1 < x.size(0) <= 8
         )
     )
+
+
+def use_fused_layer_norm(x: torch.Tensor, training: bool):
+    return (
+        not training
+        and not x.requires_grad
+        and x.device.type == 'xpu'
+        and (
+            get_xpu_device_type(x) not in ["arc", "flex"]
+            or x.reshape(-1, x.size(-1)).size(0) == 1
+        )
+    )
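To make the gating behaviour concrete, here is a self-contained sketch of the same heuristic with the device lookup stubbed out. The fake_arc_device_type stub, the dropped x.device.type == 'xpu' check, and the CPU tensors are assumptions purely so the example runs on a machine without an Intel GPU.

import torch

def fake_arc_device_type(x: torch.Tensor) -> str:
    # Stand-in for get_xpu_device_type(x): pretend every tensor lives on an Arc GPU.
    return "arc"

def use_fused_layer_norm_sketch(x: torch.Tensor, training: bool) -> bool:
    # Mirrors the new heuristic (minus the xpu device check): the fused layer norm
    # is skipped exactly where the XMX path would kick in, i.e. on Arc/Flex with an
    # effective batch size greater than 1.
    return (
        not training
        and not x.requires_grad
        and (
            fake_arc_device_type(x) not in ["arc", "flex"]
            or x.reshape(-1, x.size(-1)).size(0) == 1
        )
    )

# Single-token decode on "Arc": fused layer norm is still allowed.
print(use_fused_layer_norm_sketch(torch.randn(1, 1, 4096), training=False))   # True
# Batched / prefill input on "Arc" (XMX territory): fused layer norm is disabled.
print(use_fused_layer_norm_sketch(torch.randn(4, 16, 4096), training=False))  # False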