Disable fused layer norm when using XMX to fix mpt UT (#9933)

Yishuo Wang authored 2024-01-18 16:22:12 +08:00 (committed by GitHub)
parent 1fc9dfa265
commit 7bbb98abb6
2 changed files with 14 additions and 1 deletion

@@ -37,6 +37,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch.nn import functional as F
+from bigdl.llm.transformers.models.utils import use_fused_layer_norm
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
@@ -63,7 +64,7 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
 def bloom_layer_norm_forward(self, hidden_states):
-    if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
+    if use_fused_layer_norm(hidden_states, self.training):
         import linear_q4_0
         result = linear_q4_0.fused_layer_norm(hidden_states,
                                               [self.weight.size(0)],

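For reference, a minimal sketch of the full dispatch this hunk touches: the fused XPU layer norm is now gated by the shared use_fused_layer_norm helper rather than an inline device check. Everything past the first two kernel arguments and the eager F.layer_norm fallback is an assumption for illustration, not part of this diff.

import torch
from torch.nn import functional as F

from bigdl.llm.transformers.models.utils import use_fused_layer_norm  # helper added in the second file below


def bloom_layer_norm_forward(self, hidden_states):
    # New gate: one shared predicate decides whether the fused kernel applies here.
    if use_fused_layer_norm(hidden_states, self.training):
        import linear_q4_0  # BigDL-LLM native extension providing fused_layer_norm
        return linear_q4_0.fused_layer_norm(hidden_states,
                                            [self.weight.size(0)],
                                            self.weight,  # remaining arguments assumed
                                            self.bias,
                                            self.eps)
    # Fallback to stock PyTorch layer norm (assumed; the diff only shows the fused branch).
    return F.layer_norm(hidden_states, self.normalized_shape,
                        self.weight, self.bias, self.eps)
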
@@ -283,3 +283,15 @@ def use_xmx(x: torch.Tensor, qtype: int):
             1 < x.size(0) <= 8
         )
     )
+
+
+def use_fused_layer_norm(x: torch.Tensor, training: bool):
+    return (
+        not training
+        and not x.requires_grad
+        and x.device.type == 'xpu'
+        and (
+            get_xpu_device_type(x) not in ["arc", "flex"]
+            or x.reshape(-1, x.size(-1)).size(0) == 1
+        )
+    )
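
The practical effect of the new predicate: on the devices the XMX path targets ("arc", "flex"), the fused layer norm now runs only when the input flattens to a single token row, and per the commit title skipping the fused path in that case is what fixes the mpt UT. Below is a standalone, CPU-runnable sketch of the same condition; would_use_fused_layer_norm, the device_type argument, and the shapes/device strings are hypothetical stand-ins for the real helper, get_xpu_device_type, and actual workloads.

import torch


def would_use_fused_layer_norm(x: torch.Tensor, training: bool,
                               device_type: str, on_xpu: bool = True) -> bool:
    # Mirrors the helper above: device_type stands in for get_xpu_device_type(x),
    # on_xpu stands in for x.device.type == 'xpu'.
    return (
        not training
        and not x.requires_grad
        and on_xpu
        and (
            device_type not in ["arc", "flex"]          # non-XMX device: fusion always allowed
            or x.reshape(-1, x.size(-1)).size(0) == 1   # XMX device: only single-token inputs
        )
    )


decode_step = torch.randn(1, 1, 4096)   # one token per forward pass
prefill = torch.randn(1, 32, 4096)      # 32 tokens at once

print(would_use_fused_layer_norm(decode_step, training=False, device_type="arc"))   # True
print(would_use_fused_layer_norm(prefill, training=False, device_type="arc"))       # False -> eager path
print(would_use_fused_layer_norm(prefill, training=False, device_type="other"))     # True (not gated)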