diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index 38904d58..b161f3b9 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -35,7 +35,8 @@ ggml_tensor_qtype = {"sym_int4": 2,  # q4_0 in ggml
                      "fp16": 12,
                      "fp8": 15,
                      "fp4": 16,
-                     "mixed_4bit": 17}  # Mixture of Formats Quantization 4 bits
+                     "mixed_fp4": 17,   # Mixture of Formats Quantization 4 bits
+                     "mixed_fp8": 18}   # Mixture of Formats Quantization 8 bits
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 813cc12d..f232558f 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -66,7 +66,8 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
+MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -194,6 +195,16 @@ class FP4Params(torch.nn.Parameter):
         self.convert_shape_only = convert_shape_only
         return self
 
+    def ggml_mse(self, w, ggml_qtype, device):
+        from torch.nn.functional import mse_loss
+        w_quant = ggml_convert_qtype(w, ggml_qtype,
+                                     device=device,
+                                     convert_shape_only=self.convert_shape_only)
+        w_dequant = ggml_convert_fp32(w_quant, w.shape,
+                                      reduce(mul, w.shape, 1), ggml_qtype)
+        mse = mse_loss(w_dequant, w)
+        return mse, w_quant
+
     def quantize(self, device=None):
         if not self.quantized:
             w = self.data.contiguous().float()
@@ -206,25 +217,31 @@ class FP4Params(torch.nn.Parameter):
                     # save/load
                     self.qtype = SYM_INT4
                 else:
-                    from torch.nn.functional import mse_loss
-                    w_quant_q4_0 = ggml_convert_qtype(w, SYM_INT4,
-                                                      device=device,
-                                                      convert_shape_only=self.convert_shape_only)
-                    w_q4_0_dequant = ggml_convert_fp32(w_quant_q4_0, w.shape,
-                                                       reduce(mul, w.shape, 1), SYM_INT4)
-                    w_quant_fp4 = ggml_convert_qtype(w, FP4,
-                                                     device=device,
-                                                     convert_shape_only=self.convert_shape_only)
-                    w_fp4_dequant = ggml_convert_fp32(w_quant_fp4, w.shape,
-                                                      reduce(mul, w.shape, 1), FP4)
-                    q4_0_mse = mse_loss(w_q4_0_dequant, w)
-                    fp4_mse = mse_loss(w_fp4_dequant, w)
+                    q4_0_mse, w_quant_q4_0 = self.ggml_mse(w, SYM_INT4, device=device)
+                    fp4_mse, w_quant_fp4 = self.ggml_mse(w, FP4, device=device)
                     if q4_0_mse <= fp4_mse:
                         self.qtype = SYM_INT4
                         self.data = w_quant_q4_0
                     else:
                         self.qtype = FP4
                         self.data = w_quant_fp4
+            elif self.qtype == MOFQ8:
+                if device == 'meta':
+                    w_quantized = ggml_convert_qtype(w, SYM_INT8,
+                                                     device=device,
+                                                     convert_shape_only=self.convert_shape_only)
+                    # TODO: should load from config, the current implementation doesn't support
+                    # save/load
+                    self.qtype = SYM_INT8
+                else:
+                    q8_0_mse, w_quant_q8_0 = self.ggml_mse(w, SYM_INT8, device=device)
+                    fp8_mse, w_quant_fp8 = self.ggml_mse(w, FP8, device=device)
+                    if q8_0_mse <= fp8_mse:
+                        self.qtype = SYM_INT8
+                        self.data = w_quant_q8_0
+                    else:
+                        self.qtype = FP8
+                        self.data = w_quant_fp8
             else:
                 w_quantized = ggml_convert_qtype(w, self.qtype,
                                                  device=device,
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 98a54ee1..93a42876 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -113,7 +113,7 @@ class _BaseAutoModelClass:
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
                           f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
-                          "fp4, fp8, fp16 or mixed_4bit.")
+                          "fp4, fp8, fp16, mixed_fp4 or mixed_fp8.")
         qtype = ggml_tensor_qtype[q_k]
         # In case it needs a second try,
diff --git a/python/llm/src/bigdl/llm/utils/xmx_checker.py b/python/llm/src/bigdl/llm/utils/xmx_checker.py
index e26916eb..4bef2f9b 100644
--- a/python/llm/src/bigdl/llm/utils/xmx_checker.py
+++ b/python/llm/src/bigdl/llm/utils/xmx_checker.py
@@ -25,7 +25,8 @@ NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
 FP4 = ggml_tensor_qtype["fp4"]
-MOFQ4 = ggml_tensor_qtype["mixed_4bit"]
+MOFQ4 = ggml_tensor_qtype["mixed_fp4"]
+MOFQ8 = ggml_tensor_qtype["mixed_fp8"]
 
 
 class XMXChecker:
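
Usage sketch (illustrative, not part of the patch): with load_in_low_bit="mixed_fp8", each quantized layer is converted to both q8_0 (sym_int8) and fp8, and the format whose dequantized weights give the lower MSE against the original fp32 weights is kept per layer, mirroring the existing mixed_fp4 path. The model id below is a placeholder.

from bigdl.llm.transformers import AutoModelForCausalLM

# Load with the new Mixture of Formats Quantization 8-bit option;
# the layer-wise q8_0 vs. fp8 choice happens during conversion.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",
                                             load_in_low_bit="mixed_fp8")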