disable fused layer norm on UHD (#10130)

This commit is contained in:
Yishuo Wang 2024-02-08 10:20:01 +08:00 committed by GitHub
parent a8450fc300
commit 1aa0c623ce

View file

@ -309,12 +309,13 @@ def use_xmx(x: torch.Tensor, qtype: int):
def use_fused_layer_norm(x: torch.Tensor, training: bool):
device = get_xpu_device_type(x)
return (
not training
and not x.requires_grad
and x.device.type == 'xpu'
and device in ["arc", "flex", "pvc", "mtl"] # fused layer norm cannot run on UHD
and (
get_xpu_device_type(x) not in ["arc", "flex"]
device == "mtl" # fused layer norm conflicts with XMX, so disable it when using XMX
or x.numel() // x.size(-1) == 1
)
)