change xmx condition (#10000)
This commit is contained in:
parent
8b08ad408b
commit
24b34b6e46
1 changed files with 3 additions and 4 deletions
|
|
@ -291,14 +291,13 @@ def mlp_fusion_check(x, qtype, training):
|
||||||
def use_xmx(x: torch.Tensor, qtype: int):
|
def use_xmx(x: torch.Tensor, qtype: int):
|
||||||
device = get_xpu_device_type(x)
|
device = get_xpu_device_type(x)
|
||||||
return (
|
return (
|
||||||
device in ["arc", "flex", "pvc"]
|
os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
|
||||||
|
and device in ["arc", "flex", "pvc"]
|
||||||
and qtype in [SYM_INT4, SYM_INT8, FP8]
|
and qtype in [SYM_INT4, SYM_INT8, FP8]
|
||||||
and (
|
and (
|
||||||
(device == "pvc" and 1 < x.size(0) <= 16)
|
(device == "pvc" and 1 < x.size(0) <= 16)
|
||||||
or
|
or
|
||||||
(device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
|
(device != "pvc" and 1 < x.size(0) <= 64)
|
||||||
or
|
|
||||||
1 < x.size(0) <= 8
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue