diff --git a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
index e95278e4..78f5681b 100644
--- a/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
+++ b/python/llm/src/bigdl/llm/ggml/model/llama/llama_cpp.py
@@ -999,6 +999,24 @@ _lib.ggml_dequantize_q4_0.argtypes = [
 _lib.ggml_quantize_q4_0.restype = None
 
 
+def ggml_dequantize(
+    src: ctypes.c_void_p,
+    dst: ctypes.c_void_p,
+    k: ctypes.c_size_t,
+    qtype: ctypes.c_int
+):
+    _lib.ggml_dequantize(src, dst, k, qtype)
+
+
+_lib.ggml_dequantize.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_size_t,
+    ctypes.c_int
+]
+_lib.ggml_dequantize.restype = None
+
+
 def ggml_q_format_convet_cpu2xpu(
     src: ctypes.c_void_p,
     dst: ctypes.c_void_p,
diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index 4f8891e9..38904d58 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -34,7 +34,8 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                      "nf3": 11,
                      "fp16": 12,
                      "fp8": 15,
-                     "fp4": 16}
+                     "fp4": 16,
+                     "mixed_4bit": 17}  # Mixture of Formats Quantization 4 bits
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 7d3266aa..9676b0c7 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -66,6 +66,7 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
+MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -158,6 +159,19 @@ def ggml_int4_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int):
     return dst_tensor
 
 
+def ggml_convert_fp32(tensor: torch.Tensor, weight_shape: tuple, k: int, qtype: int):
+    invalidInputError(tensor.dtype == torch.uint8,
+                      "Input tensor must be uint8")
+    src_ptr = ctypes.c_void_p(tensor.data.data_ptr())
+
+    dst_size = k
+    dst_tensor = torch.empty(weight_shape, dtype=torch.float)
+    dst_ptr = ctypes.c_void_p(dst_tensor.data.data_ptr())
+
+    ggml.ggml_dequantize(src_ptr, dst_ptr, k, qtype)
+    return dst_tensor
+
+
 # Rename to FP4Params to trigger initializing
 # the params layer with all parameters on the CPU
 # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/modeling.py#L333
@@ -183,10 +197,39 @@ class FP4Params(torch.nn.Parameter):
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
-            w_quantized = ggml_convert_qtype(w, self.qtype,
-                                             device=device,
-                                             convert_shape_only=self.convert_shape_only)
-            self.data = w_quantized
+            if self.qtype == MOFQ4:
+                if device == 'meta':
+                    w_quantized = ggml_convert_qtype(w, SYM_INT4,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    # TODO: should load from config, the current implementation doesn't support
+                    # save/load
+                    self.qtype = SYM_INT4
+                else:
+                    from torch.nn.functional import mse_loss
+                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
+                                                      device=device,
+                                                      convert_shape_only=self.convert_shape_only)
+                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
+                                                       reduce(mul, w.shape, 1), SYM_INT4)
+                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
+                                                      reduce(mul, w.shape, 1), FP4)
+                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
+                    fp4_mse = mse_loss(w_fp4_dequant, w)
+                    if q4_0_mse <= fp4_mse:
+                        self.qtype = SYM_INT4
+                        self.data = w_quant_q4_0
+                    else:
+                        self.qtype = FP4
+                        self.data = w_quant_fp4
+            else:
+                w_quantized = ggml_convert_qtype(w, self.qtype,
+                                                 device=device,
+                                                 convert_shape_only=self.convert_shape_only)
+                self.data = w_quantized
             self.quantized = True
             self._shape = w.shape
             return self
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 36e985df..a4663145 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -111,7 +111,7 @@ class _BaseAutoModelClass:
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                          "fp4, fp8 or fp16.")
+                          "fp4, fp8, fp16 or mixed_4bit.")
         qtype = ggml_tensor_qtype[q_k]
         # In case it needs a second try,
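
Usage note (not part of the patch): a minimal sketch of how the new "mixed_4bit" value could be passed through the existing load_in_low_bit argument of bigdl-llm's from_pretrained. The model path below is a hypothetical placeholder; with this qtype each weight tensor is quantized to both sym_int4 (q4_0) and fp4, and the format with the lower dequantization MSE is kept per tensor.

    # Minimal usage sketch, assuming the existing bigdl-llm from_pretrained API;
    # "path/to/llama-model" is a placeholder, not something introduced by this patch.
    from bigdl.llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/llama-model",           # hypothetical local path or HF model id
        load_in_low_bit="mixed_4bit",    # per-tensor choice of q4_0 vs fp4 by MSE
    )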