[LLM] use pytorch linear for large input matrix (#8492)

* use pytorch linear for large input matrix
* only works on server
* fix style
* optimize memory
* first check server
* revert
* address comments
* fix style
This commit is contained in:
parent 6504e31a97
commit 57e880f63a

3 changed files with 50 additions and 7 deletions
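In short: on server-class CPUs, sym_int4 layers whose flattened input has at least 96 rows now dequantize the int4 weight to fp32 and go through PyTorch's linear, while all other cases keep the existing ggml quantized matmul. A minimal sketch of the dispatch rule, with placeholder values standing in for the module-level constants added below:

# Sketch only: the constants mirror those introduced in this diff; the
# sym_int4 id here is an illustrative placeholder, not the library's value.
IS_SERVER = True              # is_server(): AVX512-VNNI detected on this CPU
TORCH_LINEAR_THRESHOLD = 96   # minimum rows in the flattened input
SYM_INT4 = 2                  # placeholder for ggml_tensor_qtype["sym_int4"]

def use_torch_linear(batch_rows: int, qtype: int) -> bool:
    # Large inputs on servers: dequantize once, then torch F.linear;
    # everything else stays on the ggml quantized matmul kernel.
    return IS_SERVER and qtype == SYM_INT4 and batch_rows >= TORCH_LINEAR_THRESHOLD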
@@ -978,6 +978,22 @@ _lib.ggml_qk_size.argtypes = [
 _lib.ggml_qk_size.restype = ctypes.c_int
 
 
+def ggml_dequantize_q4_0(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    k: ctypes.c_int,
+):
+    _lib.ggml_dequantize_q4_0(src, dst, k)
+
+
+_lib.ggml_dequantize_q4_0.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+]
+_lib.ggml_dequantize_q4_0.restype = None
+
+
 def ggml_compute_forward_mul_mat_q_fp32(src_0_ne,  # type: ctypes.Array[ctypes.c_int64]
                                         src_0_data,  # type: ctypes.c_void_p
                                         src_0_qtype,  # type: int
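A hedged usage sketch for the new binding. It assumes the compiled ggml library bundled with bigdl-llm and the common q4_0 layout of 18 bytes per 32-value block (2-byte scale plus 16 nibble bytes); the zero-filled payload below is a dummy that dequantizes to zeros.

import ctypes
import torch
import bigdl.llm.ggml.model.llama.llama_cpp as ggml

k = 4096                                                # fp32 elements to recover
payload = torch.zeros(k // 32 * 18, dtype=torch.uint8)  # dummy q4_0 buffer (assumed block size)
out = torch.empty(k, dtype=torch.float)
ggml.ggml_dequantize_q4_0(
    ctypes.c_void_p(payload.data_ptr()),
    ctypes.c_void_p(out.data_ptr()),
    k,
)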
@@ -51,9 +51,14 @@ from torch import Tensor, device, dtype, nn
 T = TypeVar("T", bound="torch.nn.Module")
 
 import bigdl.llm.ggml.model.llama.llama_cpp as ggml
+from bigdl.llm.utils.isa_checker import is_server
 
 import torch
 import ctypes
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+
+IS_SERVER = is_server()
+TORCH_LINEAR_THRESHOLD = 96
+SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 
 
 def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
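The new imports and module-level gates can be probed directly; a small check, assuming the bigdl-llm package layout shown in this diff:

from bigdl.llm.utils.isa_checker import is_server
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

print("server-class CPU:", is_server())                     # gates the torch path
print("sym_int4 qtype id:", ggml_tensor_qtype["sym_int4"])  # the only qtype it applies to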
@@ -82,6 +87,19 @@ def ggml_convert_quant(tensor: torch.Tensor, qtype: int, convert_shape_only=False):
     return dst_tensor
 
 
+def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_size = k
+    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
+    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
+
+    ggml.ggml_dequantize_q4_0(src_ptr, dst_ptr, k)
+    return dst_tensor
+
+
 class ParamsQuant(torch.nn.Parameter):
     def __new__(cls,
                 data=None,
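A hedged example of the new helper materializing an fp32 weight matrix from a q4_0 payload. The import path and the 18-bytes-per-block sizing are assumptions; a real payload would come from ggml_convert_quant rather than the zero-filled dummy here.

import torch
from bigdl.llm.transformers.linear_quant import ggml_int4_convert_fp32  # assumed module path

out_len, in_len = 11008, 4096
n = out_len * in_len
payload = torch.zeros(n // 32 * 18, dtype=torch.uint8)  # dummy q4_0 buffer (assumed layout)
w_fp32 = ggml_int4_convert_fp32(payload, (out_len, in_len), n)
assert w_fp32.shape == (out_len, in_len) and w_fp32.dtype == torch.float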
@@ -193,6 +211,7 @@ class LinearQuant(nn.Linear):
         self.in_len = input_features
         self.out_len = output_features
         self.weight_shape = (self.out_len, self.in_len)
+        self.weight_length = self.out_len * self.in_len
         self.qtype = qtype
 
     def forward(self, x: torch.Tensor):
@@ -201,15 +220,19 @@ class LinearQuant(nn.Linear):
             self.bias.data = self.bias.data.to(x.dtype)
 
         x_shape = x.shape
-        x = x.view(-1, x_shape[-1])
+        x_2d = x.view(-1, x_shape[-1])
 
         x0 = self.weight.data
 
-        result = ggml_matmul_src1_x_src0_t(x0, x, self.weight_shape, self.qtype)
-        new_shape = x_shape[:-1] + (self.out_len,)
-        result = result.view(new_shape)
-
-        if self.bias is not None:
-            result += self.bias
+        # todo may need to set a different number on different platforms
+        if IS_SERVER and self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
+            x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
+            result = F.linear(x, x0_fp32, self.bias)
+        else:
+            result = ggml_matmul_src1_x_src0_t(x0, x_2d, self.weight_shape, self.qtype)
+            new_shape = x_shape[:-1] + (self.out_len,)
+            result = result.view(new_shape)
+            if self.bias is not None:
+                result += self.bias
 
         return result.to(x.dtype)
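One subtlety in the new forward: the torch branch passes the original x (possibly 3-D) straight to F.linear, while the ggml branch multiplies the flattened x_2d and reshapes afterwards. That works because linear broadcasts over all leading dimensions; a self-contained check:

import torch
import torch.nn.functional as F

x = torch.randn(2, 128, 4096)        # (batch, seq, in_len)
w = torch.randn(11008, 4096)         # (out_len, in_len), e.g. a dequantized fp32 weight
out = F.linear(x, w)                 # broadcasts over all leading dims
assert out.shape == (2, 128, 11008)  # no manual view/reshape needed on this path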
@@ -75,3 +75,7 @@ def check_avx512():
 
 def check_avx512_vnni():
     return isa_checker.check_avx512_vnni() and isa_checker.check_avx512()
+
+
+def is_server():
+    return check_avx512_vnni()
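is_server() approximates "server CPU" purely by ISA, using AVX512-VNNI support as the heuristic. A trivial gating example:

from bigdl.llm.utils.isa_checker import is_server

if is_server():
    print("AVX512-VNNI present: large sym_int4 matmuls take the torch fp32 path")
else:
    print("no AVX512-VNNI: all matmuls stay on the ggml quantized kernel")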