diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
index 68233c0f..f7e84321 100644
--- a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
+++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
@@ -978,6 +978,22 @@ _lib.ggml_qk_size.argtypes = [
 _lib.ggml_qk_size.restype = ctypes.c_int
 
 
+def ggml_dequantize_q4_0(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    k: ctypes.c_int,
+):
+    _lib.ggml_dequantize_q4_0(src, dst, k)
+
+
+_lib.ggml_dequantize_q4_0.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+]
+_lib.ggml_dequantize_q4_0.restype = None
+
+
 def ggml_compute_forward_mul_mat_q_fp32(src_0_ne,  # type: ctypes.Array[ctypes.c_int64]
                                         src_0_data,  # type: ctypes.c_void_p
                                         src_0_qtype,  # type: int
diff --git a/python/llm/src/bigdl/llm/transformers/linear_quant.py b/python/llm/src/bigdl/llm/transformers/linear_quant.py
index 622e7ea6..19f4e6c1 100644
--- a/python/llm/src/bigdl/llm/transformers/linear_quant.py
+++ b/python/llm/src/bigdl/llm/transformers/linear_quant.py
@@ -51,9 +51,14 @@ from torch import Tensor, device, dtype, nn
 T = TypeVar("T", bound="torch.nn.Module")
 
 import bigdl.llm.ggml.model.llama.llama_cpp as ggml
+from bigdl.llm.utils.isa_checker import is_server
 import torch
 import ctypes
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
+IS_SERVER = is_server()
+TORCH_LINEAR_THRESHOLD = 96
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 
 
 def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
@@ -82,6 +87,19 @@ def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=Fals
     return dst_tensor
 
 
+def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_size = k
+    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
+    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
+
+    ggml.ggml_dequantize_q4_0(src_ptr, dst_ptr, k)
+    return dst_tensor
+
+
 class ParamsQuant(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
@@ -193,6 +211,7 @@ class LinearQuant(nn.Linear):
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
 
     def forward(self, x: torch.Tensor):
@@ -201,15 +220,19 @@ class LinearQuant(nn.Linear):
             self.bias.data = self.bias.data.to(x.dtype)
 
         x_shape = x.shape
-        x = x.view(-1, x_shape[-1])
+        x_2d = x.view(-1, x_shape[-1])
 
         x0 = self.weight.data
-        result = ggml_matmul_src1_x_src0_t(x0, x, self.weight_shape, self.qtype)
-        new_shape = x_shape[:-1] + (self.out_len,)
-        result = result.view(new_shape)
-
-        if self.bias is not None:
-            result += self.bias
+        # TODO: may need to set a different threshold on different platforms
+        if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+            x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
+            result = F.linear(x, x0_fp32, self.bias)
+        else:
+            result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
+            new_shape = x_shape[:-1] + (self.out_len,)
+            result = result.view(new_shape)
+            if self.bias is not None:
+                result += self.bias
 
         return result.to(x.dtype)
 
diff --git a/python/llm/src/bigdl/llm/utils/isa_checker.py b/python/llm/src/bigdl/llm/utils/isa_checker.py
index 397ebed9..0118506b 100644
--- a/python/llm/src/bigdl/llm/utils/isa_checker.py
+++ b/python/llm/src/bigdl/llm/utils/isa_checker.py
@@ -75,3 +75,7 @@ def check_avx512():
 
 def check_avx512_vnni():
     return isa_checker.check_avx512_vnni() and isa_checker.check_avx512()
+
+
+def is_server():
+    return check_avx512_vnni()
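
Taken together, the linear_quant.py changes route large batches around the ggml int4 kernel: on server-class CPUs (is_server(), i.e. AVX-512 VNNI detected) with sym_int4 weights, once the flattened input has at least TORCH_LINEAR_THRESHOLD (96) rows, the weights are dequantized to fp32 once and torch's GEMM does the matmul via F.linear; smaller batches keep the quantized path. A minimal self-contained sketch of that dispatch follows. dequantize_to_fp32 is a hypothetical stand-in for ggml_int4_convert_fp32, and the small-batch branch simulates ggml_matmul_src1_x_src0_t with plain torch so the example runs; neither is the real BigDL call.

import torch
import torch.nn.functional as F

IS_SERVER = True             # in the diff: is_server(), i.e. AVX-512 VNNI present
TORCH_LINEAR_THRESHOLD = 96  # flattened input rows at or above which torch GEMM is used


def dequantize_to_fp32(packed: torch.Tensor, weight_shape: tuple) -> torch.Tensor:
    # Hypothetical stand-in: the real helper calls ggml_dequantize_q4_0 on the
    # packed uint8 data and fills an fp32 tensor of weight_shape.
    return torch.randn(weight_shape)


def forward(x, packed_weight, weight_shape, bias=None):
    x_2d = x.view(-1, x.shape[-1])
    if IS_SERVER and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
        # Large batch (e.g. prompt prefill): dequantize once, let torch's GEMM run.
        w_fp32 = dequantize_to_fp32(packed_weight, weight_shape)
        return F.linear(x, w_fp32, bias)
    # Small batch (e.g. token-by-token decode): the real code keeps the int4
    # ggml matmul, which never materializes fp32 weights; simulated here.
    w_fp32 = dequantize_to_fp32(packed_weight, weight_shape)
    result = x_2d @ w_fp32.t()
    result = result.view(x.shape[:-1] + (weight_shape[0],))
    return result if bias is None else result + bias


# A 128-token prompt through a 64 -> 32 layer crosses the threshold and
# takes the F.linear path:
y = forward(torch.randn(1, 128, 64), torch.empty(0, dtype=torch.uint8), (32, 64))
print(y.shape)  # torch.Size([1, 128, 32])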
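Note also that the k forwarded to ggml_dequantize_q4_0 is an element count (weight_length = out_len * in_len), not a byte count. Assuming the standard ggml q4_0 layout (blocks of 32 values, each stored as one fp16 scale plus 16 bytes of packed 4-bit values, 18 bytes per block), the packed uint8 buffer relates to k as in this sketch; q4_0_buffer_bytes is an illustrative helper, not a library function.

QK4_0 = 32                    # values per q4_0 block
BLOCK_BYTES = 2 + QK4_0 // 2  # fp16 scale + 16 bytes of packed nibbles = 18


def q4_0_buffer_bytes(out_len: int, in_len: int) -> int:
    k = out_len * in_len  # matches self.weight_length in the diff
    assert k % QK4_0 == 0, "k must be a multiple of the q4_0 block size"
    return (k // QK4_0) * BLOCK_BYTES


# e.g. a 4096 x 4096 layer: 16,777,216 values -> 524,288 blocks -> 9,437,184 bytes
print(q4_0_buffer_bytes(4096, 4096))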