disable fused layer norm on UHD (#10130)
This commit is contained in:
parent
a8450fc300
commit
1aa0c623ce
1 changed files with 3 additions and 2 deletions
|
|
@ -309,12 +309,13 @@ def use_xmx(x: torch.Tensor, qtype: int):
|
|||
|
||||
|
||||
def use_fused_layer_norm(x: torch.Tensor, training: bool):
|
||||
device = get_xpu_device_type(x)
|
||||
return (
|
||||
not training
|
||||
and not x.requires_grad
|
||||
and x.device.type == 'xpu'
|
||||
and device in ["arc", "flex", "pvc", "mtl"] # fused layer norm cannot run on UHD
|
||||
and (
|
||||
get_xpu_device_type(x) not in ["arc", "flex"]
|
||||
device == "mtl" # fused layer norm conflicts with XMX, so disable it when using XMX
|
||||
or x.numel() // x.size(-1) == 1
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue