change xmx condition (#9896)
parent 0e69bfe6b0
commit 6637860ddf

3 changed files with 18 additions and 53 deletions
@@ -470,7 +470,7 @@ class LowBitLinear(nn.Linear):
         try:
             import intel_extension_for_pytorch
             import linear_q4_0
-            from bigdl.llm.utils.xmx_checker import use_xmx
+            from bigdl.llm.transformers.models.utils import use_xmx
         except ModuleNotFoundError:
             invalidInputError(False,
                               "Please `pip install bigdl_core_xe` first.")
@@ -21,6 +21,11 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 
 
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
+SYM_INT8 = ggml_tensor_qtype["sym_int8"]
+FP8 = ggml_tensor_qtype["fp8"]
+
+
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
@@ -263,3 +268,15 @@ def mlp_fusion_check(x, qtype, training):
     if training or x.requires_grad:
         return False
     return True
+
+
+def use_xmx(x: torch.Tensor, qtype: int):
+    device = get_xpu_device_type(x)
+    return (
+        device in ["arc", "flex", "pvc"]
+        and qtype in [SYM_INT4, SYM_INT8, FP8]
+        and (
+            (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
+            or 1 < x.size(0) <= 8
+        )
+    )
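For reference, here is a minimal self-contained sketch of the new gating condition. It mirrors the logic of use_xmx above, but it takes the device type as a plain string instead of deriving it with get_xpu_device_type, and the qtype integers are placeholders rather than the real ggml_tensor_qtype values:

import torch

# Placeholder qtype ids for illustration only; the real values come from
# bigdl.llm.ggml.quantize.ggml_tensor_qtype.
SYM_INT4, SYM_INT8, FP8 = 2, 8, 15


def use_xmx_sketch(x: torch.Tensor, qtype: int, device: str) -> bool:
    # XMX kernels are considered only on Arc/Flex/PVC with a supported qtype,
    # and only for batch sizes in (1, 8]; fp32 inputs get a wider (1, 64]
    # window on non-PVC devices.
    return (
        device in ["arc", "flex", "pvc"]
        and qtype in [SYM_INT4, SYM_INT8, FP8]
        and (
            (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
            or 1 < x.size(0) <= 8
        )
    )


# A batch of 32 fp32 rows now qualifies on Arc but not on PVC.
x = torch.randn(32, 4096, dtype=torch.float32)
print(use_xmx_sketch(x, SYM_INT4, "arc"))  # True
print(use_xmx_sketch(x, SYM_INT4, "pvc"))  # False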
@@ -1,52 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-import intel_extension_for_pytorch as ipex
-from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-
-
-SYM_INT4 = ggml_tensor_qtype["sym_int4"]
-SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-NF4 = ggml_tensor_qtype["nf4"]
-NF3 = ggml_tensor_qtype["nf3"]
-FP8 = ggml_tensor_qtype["fp8"]
-FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
-MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
-
-
-class XMXChecker:
-    def __init__(self):
-        self.support_xmx = self.check_xmx()
-        self.supported_qtype = [SYM_INT4, SYM_INT8, FP8]
-
-    @staticmethod
-    def check_xmx():
-        name = torch.xpu.get_device_name(0)
-        # todo: not sure how to check xmx or how to get device name for now
-        return "Arc(TM)" in name or "GPU Max" in name or "GPU Flex" in name
-
-    def check(self, input_tensor: torch.Tensor, qtype: int):
-        return self.support_xmx and 1 < input_tensor.shape[0] <= 8 and \
-            qtype in self.supported_qtype
-
-
-xmx_checker = XMXChecker()
-
-
-def use_xmx(input_tensor: torch.Tensor, qtype: int):
-    return xmx_checker.check(input_tensor, qtype)
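Net effect of the change: the standalone XMXChecker above, which matched device-name strings from torch.xpu.get_device_name(0) and only allowed batch sizes in (1, 8], is removed in favour of the use_xmx helper now imported from bigdl.llm.transformers.models.utils. The new check keeps the SYM_INT4/SYM_INT8/FP8 qtype restriction and the (1, 8] window, but derives the device via get_xpu_device_type and additionally admits fp32 inputs with batch sizes up to 64 on Arc and Flex (not PVC).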