From 9f34da7cdba7d9ca0a51f370de6a93b688cbc3d4 Mon Sep 17 00:00:00 2001 From: Guancheng Fu <110874468+gc-fu@users.noreply.github.com> Date: Mon, 15 Jan 2024 15:42:15 +0800 Subject: [PATCH] Update PVC XMX condition (#9901) * update pvc xmx condition * update condition * update conditon --- python/llm/src/bigdl/llm/transformers/models/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index 8e33b405..0191724a 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -276,7 +276,10 @@ def use_xmx(x: torch.Tensor, qtype: int): device in ["arc", "flex", "pvc"] and qtype in [SYM_INT4, SYM_INT8, FP8] and ( + (device == "pvc" and 1 < x.size(0) <= 16) + or (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64) - or 1 < x.size(0) <= 8 + or + 1 < x.size(0) <= 8 ) )