change xmx condition (#10000)

This commit is contained in:
Yishuo Wang 2024-01-25 17:48:11 +08:00 committed by GitHub
parent 8b08ad408b
commit 24b34b6e46

View file

@ -291,14 +291,13 @@ def mlp_fusion_check(x, qtype, training):
def use_xmx(x: torch.Tensor, qtype: int): def use_xmx(x: torch.Tensor, qtype: int):
device = get_xpu_device_type(x) device = get_xpu_device_type(x)
return ( return (
device in ["arc", "flex", "pvc"] os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
and device in ["arc", "flex", "pvc"]
and qtype in [SYM_INT4, SYM_INT8, FP8] and qtype in [SYM_INT4, SYM_INT8, FP8]
and ( and (
(device == "pvc" and 1 < x.size(0) <= 16) (device == "pvc" and 1 < x.size(0) <= 16)
or or
(device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64) (device != "pvc" and 1 < x.size(0) <= 64)
or
1 < x.size(0) <= 8
) )
) )