disable fused layer norm on UHD (#10130)
This commit is contained in:
parent
a8450fc300
commit
1aa0c623ce
1 changed files with 3 additions and 2 deletions
|
|
@ -309,12 +309,13 @@ def use_xmx(x: torch.Tensor, qtype: int):
|
||||||
|
|
||||||
|
|
||||||
def use_fused_layer_norm(x: torch.Tensor, training: bool):
|
def use_fused_layer_norm(x: torch.Tensor, training: bool):
|
||||||
|
device = get_xpu_device_type(x)
|
||||||
return (
|
return (
|
||||||
not training
|
not training
|
||||||
and not x.requires_grad
|
and not x.requires_grad
|
||||||
and x.device.type == 'xpu'
|
and device in ["arc", "flex", "pvc", "mtl"] # fused layer norm cannot run on UHD
|
||||||
and (
|
and (
|
||||||
get_xpu_device_type(x) not in ["arc", "flex"]
|
device == "mtl" # fused layer norm conflicts with XMX, so disable it when using XMX
|
||||||
or x.numel() // x.size(-1) == 1
|
or x.numel() // x.size(-1) == 1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue