[LLM] use pytorch linear for large input matrix (#8492)
* use pytorch linear for large input matrix
* only works on server
* fix style
* optimize memory
* first check server
* revert
* address comments
* fix style
parent 6504e31a97
commit 57e880f63a

3 changed files with 50 additions and 7 deletions
@@ -978,6 +978,22 @@ _lib.ggml_qk_size.argtypes = [
 _lib.ggml_qk_size.restype = ctypes.c_int
+
+
+def ggml_dequantize_q4_0(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    k: ctypes.c_int,
+):
+    _lib.ggml_dequantize_q4_0(src, dst, k)
+
+
+_lib.ggml_dequantize_q4_0.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+]
+_lib.ggml_dequantize_q4_0.restype = None
 
 
 def ggml_compute_forward_mul_mat_q_fp32(src_0_ne,  # type: ctypes.Array[ctypes.c_int64]
                                         src_0_data,  # type: ctypes.c_void_p
                                         src_0_qtype,  # type: int
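The two _lib assignments above are what make the new binding safe to call: argtypes tells ctypes how to marshal each argument, and restype = None marks the C function as returning void. For reference, a self-contained sketch of the same calling pattern (POSIX only, using libc's memcpy as a stand-in for the ggml kernel; none of this is part of the patch) showing how a torch tensor's storage is handed to C the way this wrapper hands over src and dst:

import ctypes
import torch

# CDLL(None) returns the already-loaded C runtime on POSIX systems
libc = ctypes.CDLL(None)

# declare the C signature first, as the patch does for ggml_dequantize_q4_0
libc.memcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
libc.memcpy.restype = ctypes.c_void_p

src = torch.arange(8, dtype=torch.float32)
dst = torch.empty(8, dtype=torch.float32)

# data_ptr() exposes the raw buffer address; c_void_p wraps it for the call,
# mirroring the c_void_p arguments the new wrapper expects
libc.memcpy(ctypes.c_void_p(dst.data_ptr()),
            ctypes.c_void_p(src.data_ptr()),
            ctypes.c_size_t(src.numel() * src.element_size()))

assert torch.equal(src, dst)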
@@ -51,9 +51,14 @@ from torch import Tensor, device, dtype, nn
 T = TypeVar("T", bound="torch.nn.Module")
 
 import bigdl.llm.ggml.model.llama.llama_cpp as ggml
+from bigdl.llm.utils.isa_checker import is_server
 
 import torch
 import ctypes
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+IS_SERVER = is_server()
+TORCH_LINEAR_THRESHOLD = 96
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 
 
 def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
@@ -82,6 +87,19 @@ def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
     return dst_tensor
+
+
+def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_size = k
+    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
+    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
+
+    ggml.ggml_dequantize_q4_0(src_ptr, dst_ptr, k)
+    return dst_tensor
 
 
 class ParamsQuant(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
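Together with the binding added above, ggml_int4_convert_fp32 inflates a packed q4_0 weight back into a dense (out_len, in_len) fp32 matrix that torch's linear can consume directly. A hedged sketch of the intended round trip, assuming ggml_convert_quant returns the packed uint8 buffer as this file's surrounding code suggests (exact import path omitted, since the diff does not show the module name):

import torch

# assumed to come from the module patched above:
#   ggml_convert_quant, ggml_int4_convert_fp32, SYM_INT4

out_len, in_len = 64, 64                 # weight_shape = (out_len, in_len)
w = torch.randn(out_len, in_len)

packed = ggml_convert_quant(w, SYM_INT4)             # uint8 q4_0 buffer
restored = ggml_int4_convert_fp32(packed, (out_len, in_len),
                                  out_len * in_len)  # k = weight_length

# q4_0 is lossy 4-bit quantization: expect close, not identical, values
print((restored - w).abs().max())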
@@ -193,6 +211,7 @@ class LinearQuant(nn.Linear):
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
 
     def forward(self, x: torch.Tensor):
@@ -201,15 +220,19 @@ class LinearQuant(nn.Linear):
             self.bias.data = self.bias.data.to(x.dtype)
 
         x_shape = x.shape
-        x = x.view(-1, x_shape[-1])
+        x_2d = x.view(-1, x_shape[-1])
 
         x0 = self.weight.data
 
-        result = ggml_matmul_src1_x_src0_t(x0, x, self.weight_shape, self.qtype)
-        new_shape = x_shape[:-1] + (self.out_len,)
-        result = result.view(new_shape)
-
-        if self.bias is not None:
-            result += self.bias
+        # TODO: may need to set a different threshold on different platforms
+        if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+            x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
+            result = F.linear(x, x0_fp32, self.bias)
+        else:
+            result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
+            new_shape = x_shape[:-1] + (self.out_len,)
+            result = result.view(new_shape)
+            if self.bias is not None:
+                result += self.bias
 
         return result.to(x.dtype)
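In words, the new forward path flattens every leading dimension into rows; on a server-class CPU with at least TORCH_LINEAR_THRESHOLD (96) rows of sym_int4 input, it dequantizes the weight once and lets PyTorch's fp32 GEMM handle the large matmul, otherwise it keeps the existing ggml kernel. Note that F.linear is called with the original x rather than x_2d because it already broadcasts over leading dimensions, so that branch needs no reshape. A self-contained sketch of just this dispatch logic, with plain fp32 stand-ins for both backends (not the patch's actual kernels):

import torch
import torch.nn.functional as F

TORCH_LINEAR_THRESHOLD = 96          # same constant as in the patch

def forward_sketch(x, weight, bias=None):
    out_len = weight.shape[0]
    x_shape = x.shape
    x_2d = x.view(-1, x_shape[-1])   # rows = product of all leading dims

    if x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
        # large-input path: F.linear accepts any leading dims, no view needed
        return F.linear(x, weight, bias)

    # small-input path: stand-in for ggml_matmul_src1_x_src0_t, which
    # operates on the flattened 2-D view and needs reshaping afterwards
    result = x_2d @ weight.t()
    if bias is not None:
        result = result + bias
    return result.view(x_shape[:-1] + (out_len,))

x = torch.randn(4, 32, 16)           # (batch, seq, in_len) -> 128 rows
w = torch.randn(8, 16)               # (out_len, in_len)
print(forward_sketch(x, w).shape)    # torch.Size([4, 32, 8])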
@@ -75,3 +75,7 @@ def check_avx512():
 
 def check_avx512_vnni():
     return isa_checker.check_avx512_vnni() and isa_checker.check_avx512()
+
+
+def is_server():
+    return check_avx512_vnni()
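So "server" here is really a proxy for AVX-512 VNNI support: the fp32 torch fallback only pays off on CPUs wide enough to win the large GEMM. A small hedged sketch of how the flag is consumed, mirroring the IS_SERVER constant added to the linear module (the import path is the one shown in the diff; pick_backend is illustrative, not part of the patch):

from bigdl.llm.utils.isa_checker import is_server

IS_SERVER = is_server()  # True only when both AVX-512 VNNI and AVX-512 are present

def pick_backend(n_rows, threshold=96):
    # same condition shape as LinearQuant.forward: server CPU + enough rows
    return "torch" if IS_SERVER and n_rows >= threshold else "ggml"

print(pick_backend(128))  # "torch" on a VNNI-capable server, "ggml" elsewhere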