[LLM] Use fp32 as dtype when batch_size <=8 and qtype is q4_0/q8_0/fp8 (#9365)

Yishuo Wang 2023-11-08 09:54:53 +08:00 committed by GitHub
parent 84ab614aab
commit bfd9f88f0d
2 changed files with 54 additions and 1 deletion


@@ -426,6 +426,7 @@ class LowBitLinear(nn.Linear):
         try:
             import intel_extension_for_pytorch
             import linear_q4_0
+            from bigdl.llm.utils.xmx_checker import use_xmx
         except ModuleNotFoundError:
             invalidInputError(False,
                               "Please `pip install bigdl_core_xe` first.")
@@ -440,7 +441,8 @@ class LowBitLinear(nn.Linear):
         # current workaround to reduce first token latency of fp32 input
         # sometimes fp16 cause nan and training instability
         # disable the conversion when training
-        if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32:
+        if self.conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and \
+                not use_xmx(x_2d, self.weight.qtype):
             x_2d = x_2d.half()
         result = linear_q4_0.forward_new(x_2d, self.weight.data, self.weight.qtype,
                                          input_seq_size)
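For reference, here is a minimal standalone sketch of the new branching logic; the helper name pick_input_dtype and the boolean argument xmx_path are illustrative, not part of the BigDL code:

import torch

def pick_input_dtype(x_2d: torch.Tensor, conver_to_half: bool, xmx_path: bool) -> torch.Tensor:
    # fp16 is only a first-token-latency workaround for fp32 inputs with batch > 1;
    # when the XMX path applies (1 < batch size <= 8 and qtype is q4_0/q8_0/fp8),
    # the input stays in fp32, as described by the commit title.
    if conver_to_half and x_2d.shape[0] > 1 and x_2d.dtype == torch.float32 and not xmx_path:
        x_2d = x_2d.half()
    return x_2d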


@@ -0,0 +1,51 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import intel_extension_for_pytorch as ipex
+
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
+SYM_INT8 = ggml_tensor_qtype["sym_int8"]
+NF4 = ggml_tensor_qtype["nf4"]
+NF3 = ggml_tensor_qtype["nf3"]
+FP8 = ggml_tensor_qtype["fp8"]
+FP4 = ggml_tensor_qtype["fp4"]
+MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+
+
+class XMXChecker:
+    def __init__(self):
+        self.support_xmx = self.check_xmx()
+        self.supported_qtype = [SYM_INT4, SYM_INT8, FP8]
+
+    @staticmethod
+    def check_xmx():
+        name = torch.xpu.get_device_name(0)
+        # todo: not sure how to check xmx or how to get device name for now
+        return "Arc(TM)" in name or "GPU Max" in name or "GPU Flex" in name
+
+    def check(self, input_tensor: torch.Tensor, qtype: int):
+        return self.support_xmx and 1 < input_tensor.shape[0] <= 8 and \
+            qtype in self.supported_qtype
+
+
+xmx_checker = XMXChecker()
+
+
+def use_xmx(input_tensor: torch.Tensor, qtype: int):
+    return xmx_checker.check(input_tensor, qtype)
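A usage sketch of the new helper, assuming an Intel GPU (XPU) environment with intel_extension_for_pytorch installed; the tensor shape and chosen qtype below are illustrative:

import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.xmx_checker import use_xmx

x_2d = torch.randn(4, 4096, device="xpu")  # fp32 input, batch of 4 rows
qtype = ggml_tensor_qtype["sym_int4"]      # i.e. q4_0
# True only on Arc/Flex/Max GPUs with 1 < batch size <= 8 and a supported qtype
# (sym_int4, sym_int8, fp8); LowBitLinear then keeps the fp32 input instead of
# casting it to fp16.
print(use_xmx(x_2d, qtype))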