Add qlora nf4 (#8782)

* add nf4

* dequant nf4

* style
This commit is contained in:
Kai Huang 2023-09-06 09:39:22 +08:00 committed by GitHub
parent 14b95ebfb4
commit 4a9ff050a1
3 changed files with 20 additions and 1 deletions

View file

@@ -999,6 +999,22 @@ _lib.ggml_dequantize_q4_0.argtypes = [
_lib.ggml_quantize_q4_0.restype = None
# def ggml_dequantize_nf4(
# src: ctypes.c_void_p,
# dst: ctypes.c_void_p,
# k: ctypes.c_int,
# ):
# _lib.ggml_dequantize_nf4(src, dst, k)
#
#
# _lib.ggml_dequantize_nf4.argtypes = [
# ctypes.c_void_p,
# ctypes.c_void_p,
# ctypes.c_int,
# ]
# _lib.ggml_dequantize_nf4.restype = None
def ggml_compute_forward_mul_mat_q_fp32(src_0_ne, # type: ctypes.Array[ctypes.c_int64]
                                        src_0_data, # type: ctypes.c_void_p
                                        src_0_qtype, # type: int

View file

@@ -29,7 +29,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml
"asym_int4": 3, # q4_1 in ggml "asym_int4": 3, # q4_1 in ggml
"sym_int5": 6, # q5_0 in ggml "sym_int5": 6, # q5_0 in ggml
"asym_int5": 7, # q5_1 in ggml "asym_int5": 7, # q5_1 in ggml
"sym_int8": 8} # q8_0 in ggml "sym_int8": 8, # q8_0 in ggml
"nf4": 10}
_llama_quantize_type = {"q4_0": 2, _llama_quantize_type = {"q4_0": 2,
"q4_1": 3, "q4_1": 3,

View file

@@ -257,6 +257,8 @@ class LowBitLinear(nn.Linear):
else:
    # CPU logic
    # todo may need to set a different number on different platforms
invalidInputError(self.qtype != ggml_tensor_qtype["nf4"],
"NF4 quantization is currently not supported on CPU")
if IS_SERVER and (not IS_SPR) and \
        self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
    x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)