[LLM] use pytorch linear for large input matrix (#8492)

* use pytorch linear for large input matrix

* only works on server

* fix style

* optimize memory

* first check server

* revert

* address comments

* fix style
This commit is contained in:
Yang Wang 2023-07-21 00:54:25 +08:00 committed by GitHub
parent 6504e31a97
commit 57e880f63a
3 changed files with 50 additions and 7 deletions

View file

@ -978,6 +978,22 @@ _lib.ggml_qk_size.argtypes = [
_lib.ggml_qk_size.restype = ctypes.c_int
def ggml_dequantize_q4_0(
src: ctypes.c_void_p,
dst: ctypes.c_void_p,
k: ctypes.c_int,
):
_lib.ggml_dequantize_q4_0(src, dst, k)
_lib.ggml_dequantize_q4_0.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
]
_lib.ggml_quantize_q4_0.restype = None
def ggml_compute_forward_mul_mat_q_fp32(src_0_ne, # type: ctypes.Array[ctypes.c_int64]
src_0_data, # type: ctypes.c_void_p
src_0_qtype, # type: int

View file

@ -51,9 +51,14 @@ from torch import Tensor, device, dtype, nn
T = TypeVar("T", bound="torch.nn.Module")
import bigdl.llm.ggml.model.llama.llama_cpp as ggml
from bigdl.llm.utils.isa_checker import is_server
import torch
import ctypes
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
IS_SERVER = is_server()
TORCH_LINEAR_THRESHOLD = 96
SYM_INT4 = ggml_tensor_qtype["sym_int4"]
def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
@ -82,6 +87,19 @@ def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=Fals
return dst_tensor
def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
invalidInputError(tensor.dtype == torch.uint8,
"Input tensor must be uint8")
src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
dst_size = k
dst_tensor = torch.empty(weight_shape, dtype=torch.float)
dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
ggml.ggml_dequantize_q4_0(src_ptr, dst_ptr, k)
return dst_tensor
class ParamsQuant(torch.nn.Parameter):
def __new__(cls,
data=None,
@ -193,6 +211,7 @@ class LinearQuant(nn.Linear):
self.in_len = input_features
self.out_len = output_features
self.weight_shape = (self.out_len, self.in_len)
self.weight_length = self.out_len * self.in_len
self.qtype = qtype
def forward(self, x: torch.Tensor):
@ -201,15 +220,19 @@ class LinearQuant(nn.Linear):
self.bias.data = self.bias.data.to(x.dtype)
x_shape = x.shape
x = x.view(-1, x_shape[-1])
x_2d = x.view(-1, x_shape[-1])
x0 = self.weight.data
result = ggml_matmul_src1_x_src0_t(x0, x, self.weight_shape, self.qtype)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.bias is not None:
result += self.bias
# todo may need to set a different number on different platforms
if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
result = F.linear(x, x0_fp32, self.bias)
else:
result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
new_shape = x_shape[:-1] + (self.out_len,)
result = result.view(new_shape)
if self.bias is not None:
result += self.bias
return result.to(x.dtype)

View file

@ -75,3 +75,7 @@ def check_avx512():
def check_avx512_vnni():
return isa_checker.check_avx512_vnni() and isa_checker.check_avx512()
def is_server():
return check_avx512_vnni()