From 6637860ddf05670f148926cbcad8935c1b741ffd Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 12 Jan 2024 19:51:48 +0800 Subject: [PATCH] change xmx condition (#9896) --- .../bigdl/llm/transformers/low_bit_linear.py | 2 +- .../bigdl/llm/transformers/models/utils.py | 17 ++++++ python/llm/src/bigdl/llm/utils/xmx_checker.py | 52 ------------------- 3 files changed, 18 insertions(+), 53 deletions(-) delete mode 100644 python/llm/src/bigdl/llm/utils/xmx_checker.py diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 1cb6b5c5..63543d20 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -470,7 +470,7 @@ class LowBitLinear(nn.Linear): try: import intel_extension_for_pytorch import linear_q4_0 - from bigdl.llm.utils.xmx_checker import use_xmx + from bigdl.llm.transformers.models.utils import use_xmx except ModuleNotFoundError: invalidInputError(False, "Please `pip install bigdl_core_xe` first.") diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index 41994aec..8e33b405 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -21,6 +21,11 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type +SYM_INT4 = ggml_tensor_qtype["sym_int4"] +SYM_INT8 = ggml_tensor_qtype["sym_int8"] +FP8 = ggml_tensor_qtype["fp8"] + + def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device): key_cache_storage = torch.empty(batch_size, num_heads, max_length, head_dim, @@ -263,3 +268,15 @@ def mlp_fusion_check(x, qtype, training): if training or x.requires_grad: return False return True + + +def use_xmx(x: torch.Tensor, qtype: int): + device = get_xpu_device_type(x) + return ( + device in ["arc", "flex", "pvc"] + and qtype in [SYM_INT4, SYM_INT8, FP8] + and ( + (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64) + or 1 < x.size(0) <= 8 + ) + ) diff --git a/python/llm/src/bigdl/llm/utils/xmx_checker.py b/python/llm/src/bigdl/llm/utils/xmx_checker.py deleted file mode 100644 index 4bef2f9b..00000000 --- a/python/llm/src/bigdl/llm/utils/xmx_checker.py +++ /dev/null @@ -1,52 +0,0 @@ -# -# Copyright 2016 The BigDL Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import torch -import intel_extension_for_pytorch as ipex -from bigdl.llm.ggml.quantize import ggml_tensor_qtype - - -SYM_INT4 = ggml_tensor_qtype["sym_int4"] -SYM_INT8 = ggml_tensor_qtype["sym_int8"] -NF4 = ggml_tensor_qtype["nf4"] -NF3 = ggml_tensor_qtype["nf3"] -FP8 = ggml_tensor_qtype["fp8"] -FP4 = ggml_tensor_qtype["fp4"] -MOFQ4 = ggml_tensor_qtype["mixed_fp4"] -MOFQ8 = ggml_tensor_qtype["mixed_fp8"] - - -class XMXChecker: - def __init__(self): - self.support_xmx = self.check_xmx() - self.supported_qtype = [SYM_INT4, SYM_INT8, FP8] - - @staticmethod - def check_xmx(): - name = torch.xpu.get_device_name(0) - # todo: not sure how to check xmx or how to get device name for now - return "Arc(TM)" in name or "GPU Max" in name or "GPU Flex" in name - - def check(self, input_tensor: torch.Tensor, qtype: int): - return self.support_xmx and 1 < input_tensor.shape[0] <= 8 and \ - qtype in self.supported_qtype - - -xmx_checker = XMXChecker() - - -def use_xmx(input_tensor: torch.Tensor, qtype: int): - return xmx_checker.check(input_tensor, qtype)
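
For reference, a minimal sketch of how the relocated predicate is meant to be
called after this patch (assumptions: an Intel GPU with
intel_extension_for_pytorch installed; the batch size, hidden size and qtype
below are illustrative and not taken from the patch):

    import torch
    import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the "xpu" device)
    from bigdl.llm.ggml.quantize import ggml_tensor_qtype
    from bigdl.llm.transformers.models.utils import use_xmx

    # Illustrative input: 4 rows of fp16 activations on an XPU device.
    x = torch.randn(4, 4096, dtype=torch.float16, device="xpu")

    if use_xmx(x, ggml_tensor_qtype["sym_int4"]):
        print("batched path that can use the XMX matrix engines")
    else:
        print("default path (single-row decode, or batches above the cap)")

The 1 < x.size(0) lower bound keeps single-row (GEMV-like) inputs on the
default kernel, and the upper bounds (8 in general, 64 for fp32 inputs on
non-PVC devices) appear to cap how large a batch is still routed through the
XMX path.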