change xmx condition (#9896)

Yishuo Wang 2024-01-12 19:51:48 +08:00 committed by GitHub
parent 0e69bfe6b0
commit 6637860ddf
3 changed files with 18 additions and 53 deletions


@@ -470,7 +470,7 @@ class LowBitLinear(nn.Linear):
 try:
     import intel_extension_for_pytorch
     import linear_q4_0
-    from bigdl.llm.utils.xmx_checker import use_xmx
+    from bigdl.llm.transformers.models.utils import use_xmx
 except ModuleNotFoundError:
     invalidInputError(False,
                       "Please `pip install bigdl_core_xe` first.")


@@ -21,6 +21,11 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
+FP8 = ggml_tensor_qtype["fp8"]
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
@@ -263,3 +268,15 @@ def mlp_fusion_check(x, qtype, training):
     if training or x.requires_grad:
         return False
     return True
+
+
+def use_xmx(x: torch.Tensor, qtype: int):
+    device = get_xpu_device_type(x)
+    return (
+        device in ["arc", "flex", "pvc"]
+        and qtype in [SYM_INT4, SYM_INT8, FP8]
+        and (
+            (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
+            or 1 < x.size(0) <= 8
+        )
+    )

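For reference, a minimal standalone sketch (not part of the commit) of how the relaxed condition above behaves. The helper would_use_xmx, the string qtype ids, and the example values below are hypothetical stand-ins for the tensor-based check, so the rule can be inspected without an XPU device.

# Hypothetical mirror of the new use_xmx rule: device, dtype and batch size are
# passed as plain values instead of being read from a live XPU tensor.
SYM_INT4, SYM_INT8, FP8 = "sym_int4", "sym_int8", "fp8"   # placeholder qtype ids

def would_use_xmx(device: str, dtype: str, batch_size: int, qtype: str) -> bool:
    return (
        device in ["arc", "flex", "pvc"]
        and qtype in [SYM_INT4, SYM_INT8, FP8]
        and (
            (device != "pvc" and dtype == "float32" and 1 < batch_size <= 64)
            or 1 < batch_size <= 8
        )
    )

print(would_use_xmx("arc", "float32", 32, SYM_INT4))   # True: fp32 on Arc/Flex now allows batch sizes up to 64
print(would_use_xmx("arc", "float16", 32, SYM_INT4))   # False: non-fp32 inputs keep the cap of 8
print(would_use_xmx("pvc", "float32", 32, SYM_INT4))   # False: PVC keeps the cap of 8 for all dtypes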

@@ -1,52 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-import intel_extension_for_pytorch as ipex
-
-from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-
-SYM_INT4 = ggml_tensor_qtype["sym_int4"]
-SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-NF4 = ggml_tensor_qtype["nf4"]
-NF3 = ggml_tensor_qtype["nf3"]
-FP8 = ggml_tensor_qtype["fp8"]
-FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
-MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
-
-
-class XMXChecker:
-    def __init__(self):
-        self.support_xmx = self.check_xmx()
-        self.supported_qtype = [SYM_INT4, SYM_INT8, FP8]
-
-    @staticmethod
-    def check_xmx():
-        name = torch.xpu.get_device_name(0)
-        # todo: not sure how to check xmx or how to get device name for now
-        return "Arc(TM)" in name or "GPU Max" in name or "GPU Flex" in name
-
-    def check(self, input_tensor: torch.Tensor, qtype: int):
-        return self.support_xmx and 1 < input_tensor.shape[0] <= 8 and \
-            qtype in self.supported_qtype
-
-
-xmx_checker = XMXChecker()
-
-
-def use_xmx(input_tensor: torch.Tensor, qtype: int):
-    return xmx_checker.check(input_tensor, qtype)
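For comparison, a similar hypothetical sketch of the deleted XMXChecker rule: it matched marketing device names from torch.xpu.get_device_name and capped every case at batch size 8, whereas the new use_xmx keys off get_xpu_device_type and allows fp32 batches up to 64 on non-PVC devices.

# Hypothetical mirror of the removed check, taking a device name string instead of
# querying torch.xpu; qtype ids are the same placeholder strings as in the sketch above.
def would_use_xmx_old(device_name: str, batch_size: int, qtype: str) -> bool:
    supports_xmx = ("Arc(TM)" in device_name or "GPU Max" in device_name
                    or "GPU Flex" in device_name)
    return supports_xmx and 1 < batch_size <= 8 and qtype in ["sym_int4", "sym_int8", "fp8"]

print(would_use_xmx_old("Intel(R) Arc(TM) A770 Graphics", 32, "sym_int4"))   # False: old rule rejects batch > 8
print(would_use_xmx_old("Intel(R) Arc(TM) A770 Graphics", 4, "sym_int4"))    # True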