diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index 2a45f642..7023a4bd 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -30,7 +30,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                      "sym_int5": 6,    # q5_0 in ggml
                      "asym_int5": 7,   # q5_1 in ggml
                      "sym_int8": 8,    # q8_0 in ggml
-                     "nf4": 10}
+                     "nf4": 10,
+                     "nf3": 11}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 89b2e7fb..931a118f 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -62,6 +62,7 @@ TORCH_LINEAR_THRESHOLD = 96
 SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 NF4 = ggml_tensor_qtype["nf4"]
+NF3 = ggml_tensor_qtype["nf3"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -101,7 +102,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
 
     src = ctypes.c_void_p(tensor.data.data_ptr())
 
-    if qtype in [SYM_INT4, SYM_INT8, NF4]:
+    if qtype in [SYM_INT4, SYM_INT8, NF4, NF3]:
         dst_tensor = torch.empty_like(tensor)
     elif qtype == ggml_tensor_qtype["sym_int5"]:
         QK = ggml.ggml_qk_size(qtype)
@@ -126,7 +127,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
 
     src = ctypes.c_void_p(tensor.data.data_ptr())
 
-    if qtype in [SYM_INT4, SYM_INT8, NF4]:
+    if qtype in [SYM_INT4, SYM_INT8, NF4, NF3]:
         dst_tensor = torch.empty_like(tensor)
     elif qtype == ggml_tensor_qtype["sym_int5"]:
         QK = ggml.ggml_qk_size(ggml_tensor_qtype["asym_int5"])
@@ -363,8 +364,8 @@ class LowBitLinear(nn.Linear):
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != ggml_tensor_qtype["nf4"],
-                              "NF4 quantization is currently not supported on CPU")
+            invalidInputError(self.qtype != NF3 and self.qtype != NF4,
+                              "NF3 and NF4 quantization are currently not supported on CPU")
             if IS_SERVER and (not IS_SPR) and \
                self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                 x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
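
For reference, a minimal usage sketch of the new qtype. This is an assumption on top of the patch, not part of it: it presumes `bigdl.llm.transformers.AutoModelForCausalLM.from_pretrained` accepts `"nf3"` for its `load_in_low_bit` keyword the same way it already accepts `"nf4"`, and that the model is moved to an XPU device, since the CPU path above still raises `invalidInputError` for NF3/NF4. The checkpoint path is a placeholder.

```python
# Hypothetical usage sketch -- assumes load_in_low_bit="nf3" is wired through
# from_pretrained the same way "nf4" is; NF3 runs on XPU only in this patch.
import torch
import intel_extension_for_pytorch as ipex  # registers the "xpu" device with PyTorch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint

# Quantize weights to NF3 at load time, then move to XPU (the CPU forward
# path in low_bit_linear.py rejects NF3/NF4, per the diff above).
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="nf3")
model = model.to("xpu")

tokenizer = AutoTokenizer.from_pretrained(model_path)
inputs = tokenizer("Low-bit quantization lets large models", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(inputs.input_ids.to("xpu"), max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```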