diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py index 979e83a0..ed62e5f7 100644 --- a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py +++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py @@ -999,6 +999,22 @@ _lib.ggml_dequantize_q4_0.argtypes = [ _lib.ggml_quantize_q4_0.restype = None +# def ggml_dequantize_nf4( +# src: ctypes.c_void_p, +# dst: ctypes.c_void_p, +# k: ctypes.c_int, +# ): +# _lib.ggml_dequantize_nf4(src, dst, k) +# +# +# _lib.ggml_dequantize_nf4.argtypes = [ +# ctypes.c_void_p, +# ctypes.c_void_p, +# ctypes.c_int, +# ] +# _lib.ggml_dequantize_nf4.restype = None + + def ggml_compute_forward_mul_mat_q_fp32(src_0_ne, # type: ctypes.Array[ctypes.c_int64] src_0_data, # type: ctypes.c_void_p src_0_qtype, # type: int diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py index c61a05fe..2a45f642 100644 --- a/python/llm/src/bigdl/llm/ggml/quantize.py +++ b/python/llm/src/bigdl/llm/ggml/quantize.py @@ -29,7 +29,8 @@ ggml_tensor_qtype = {"sym_int4": 2, # q4_0 in ggml "asym_int4": 3, # q4_1 in ggml "sym_int5": 6, # q5_0 in ggml "asym_int5": 7, # q5_1 in ggml - "sym_int8": 8} # q8_0 in ggml + "sym_int8": 8, # q8_0 in ggml + "nf4": 10} _llama_quantize_type = {"q4_0": 2, "q4_1": 3, diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py index 6e6d328c..0668c7c9 100644 --- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py +++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py @@ -257,6 +257,8 @@ class LowBitLinear(nn.Linear): else: # CPU logic # todo may need to set a different number on different platforms + invalidInputError(self.qtype != ggml_tensor_qtype["nf4"], + "NF4 quantization is currently not supported on CPU") if IS_SERVER and (not IS_SPR) and \ self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD: x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)