From 24b34b6e46c8d3bd519f45db0c92c1f24b4f19f4 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 25 Jan 2024 17:48:11 +0800
Subject: [PATCH] change xmx condition (#10000)

---
 python/llm/src/bigdl/llm/transformers/models/utils.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 2b910cdc..bb0682c8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -291,14 +291,13 @@ def mlp_fusion_check(x, qtype, training):
 def use_xmx(x: torch.Tensor, qtype: int):
     device = get_xpu_device_type(x)
     return (
-        device in ["arc", "flex", "pvc"]
+        os.environ.get("BIGDL_LLM_XMX_DISABLED", "0") != "1"
+        and device in ["arc", "flex", "pvc"]
         and qtype in [SYM_INT4, SYM_INT8, FP8]
         and (
             (device == "pvc" and 1 < x.size(0) <= 16)
             or
-            (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
-            or
-            1 < x.size(0) <= 8
+            (device != "pvc" and 1 < x.size(0) <= 64)
         )
     )