change xmx condition (#9896)
commit 6637860ddf
parent 0e69bfe6b0
3 changed files with 18 additions and 53 deletions
@@ -470,7 +470,7 @@ class LowBitLinear(nn.Linear):
         try:
             import intel_extension_for_pytorch
             import linear_q4_0
-            from bigdl.llm.utils.xmx_checker import use_xmx
+            from bigdl.llm.transformers.models.utils import use_xmx
         except ModuleNotFoundError:
             invalidInputError(False,
                               "Please `pip install bigdl_core_xe` first.")
@@ -21,6 +21,11 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.utils import get_ipex_version, get_xpu_device_type
 
 
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
+SYM_INT8 = ggml_tensor_qtype["sym_int8"]
+FP8 = ggml_tensor_qtype["fp8"]
+
+
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
     key_cache_storage = torch.empty(batch_size, num_heads,
                                     max_length, head_dim,
@@ -263,3 +268,15 @@ def mlp_fusion_check(x, qtype, training):
     if training or x.requires_grad:
         return False
     return True
+
+
+def use_xmx(x: torch.Tensor, qtype: int):
+    device = get_xpu_device_type(x)
+    return (
+        device in ["arc", "flex", "pvc"]
+        and qtype in [SYM_INT4, SYM_INT8, FP8]
+        and (
+            (device != "pvc" and x.dtype == torch.float32 and 1 < x.size(0) <= 64)
+            or 1 < x.size(0) <= 8
+        )
+    )
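For readers of the diff, the new gate is a pure predicate over the device class, the quantization type, the input dtype, and the batch dimension. The following is a minimal, dependency-free restatement of that condition for illustration only, not the shipped code: `device` stands in for the string returned by `get_xpu_device_type(x)`, `is_fp32` for `x.dtype == torch.float32`, `batch` for `x.size(0)`, and the string set `XMX_QTYPES` is a stand-in for the SYM_INT4/SYM_INT8/FP8 integer constants added above.

# Hypothetical, torch-free sketch of the new use_xmx condition.
XMX_QTYPES = {"sym_int4", "sym_int8", "fp8"}  # stand-ins for the ggml integer constants

def use_xmx_sketch(device: str, qtype: str, is_fp32: bool, batch: int) -> bool:
    # XMX is only considered on these device classes and qtypes.
    if device not in ("arc", "flex", "pvc") or qtype not in XMX_QTYPES:
        return False
    # Non-PVC devices get a wider 1 < batch <= 64 window for fp32 inputs;
    # every other case keeps the original 1 < batch <= 8 window.
    if device != "pvc" and is_fp32 and 1 < batch <= 64:
        return True
    return 1 < batch <= 8

# Sample points of the decision table:
assert use_xmx_sketch("arc", "sym_int4", True, 32)      # widened fp32 window on Arc
assert not use_xmx_sketch("pvc", "sym_int4", True, 32)  # PVC stays at batch <= 8
assert use_xmx_sketch("pvc", "fp8", False, 4)
assert not use_xmx_sketch("arc", "sym_int8", True, 1)   # batch of 1 never uses XMX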
@@ -1,52 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-import intel_extension_for_pytorch as ipex
-from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-
-
-SYM_INT4 = ggml_tensor_qtype["sym_int4"]
-SYM_INT8 = ggml_tensor_qtype["sym_int8"]
-NF4 = ggml_tensor_qtype["nf4"]
-NF3 = ggml_tensor_qtype["nf3"]
-FP8 = ggml_tensor_qtype["fp8"]
-FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
-MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
-
-
-class XMXChecker:
-    def __init__(self):
-        self.support_xmx = self.check_xmx()
-        self.supported_qtype = [SYM_INT4, SYM_INT8, FP8]
-
-    @staticmethod
-    def check_xmx():
-        name = torch.xpu.get_device_name(0)
-        # todo: not sure how to check xmx or how to get device name for now
-        return "Arc(TM)" in name or "GPU Max" in name or "GPU Flex" in name
-
-    def check(self, input_tensor: torch.Tensor, qtype: int):
-        return self.support_xmx and 1 < input_tensor.shape[0] <= 8 and \
-            qtype in self.supported_qtype
-
-
-xmx_checker = XMXChecker()
-
-
-def use_xmx(input_tensor: torch.Tensor, qtype: int):
-    return xmx_checker.check(input_tensor, qtype)
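Net effect of the change, as far as the diff shows: the removed XMXChecker probed `torch.xpu.get_device_name(0)` once at import time and gated every call on `1 < batch <= 8` regardless of dtype, while the replacement `use_xmx` classifies the device per call from the input tensor via `get_xpu_device_type` and, on Arc and Flex, widens the window to `1 < x.size(0) <= 64` for float32 inputs. A float32 batch of 32 on an Arc GPU, for example, now takes the XMX path where the old check returned False.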