From e2264e88452a1de5c93840c6e7bc90e61ae48ac4 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Wed, 25 Oct 2023 15:42:48 +0800
Subject: [PATCH] Support arc fp4 (#9266)

* support arc fp4

* fix style

* fix style
---
 python/llm/src/bigdl/llm/ggml/quantize.py            |  3 ++-
 .../llm/src/bigdl/llm/transformers/low_bit_linear.py | 11 +++++++----
 python/llm/src/bigdl/llm/transformers/model.py       | 11 ++++++-----
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index c162c694..4f8891e9 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -33,7 +33,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                      "nf4": 10,
                      "nf3": 11,
                      "fp16": 12,
-                     "fp8": 15}
+                     "fp8": 15,
+                     "fp4": 16}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index c5b85312..f9bac244 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -65,6 +65,7 @@ SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
 FP8 = ggml_tensor_qtype["fp8"]
+FP4 = ggml_tensor_qtype["fp4"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -108,7 +109,7 @@ def ggml_q_format_convet_cpu2xpu(tensor: torch.Tensor, num_elem: int, qtype: int
 
     src = ctypes.c_void_p(tensor.data.data_ptr())
 
-    if qtype in [SYM_INT4, SYM_INT8, NF4, NF3]:
+    if qtype in [SYM_INT4, SYM_INT8, NF4, NF3, FP4]:
         dst_tensor = torch.empty_like(tensor)
     elif qtype == ggml_tensor_qtype["sym_int5"]:
         QK = ggml.ggml_qk_size(qtype)
@@ -133,7 +134,7 @@ def ggml_q_format_convet_xpu2cpu(tensor: torch.Tensor, num_elem: int, qtype: int
 
     src = ctypes.c_void_p(tensor.data.data_ptr())
 
-    if qtype in [SYM_INT4, SYM_INT8, NF4, NF3]:
+    if qtype in [SYM_INT4, SYM_INT8, NF4, NF3, FP4]:
         dst_tensor = torch.empty_like(tensor)
     elif qtype == ggml_tensor_qtype["sym_int5"]:
         QK = ggml.ggml_qk_size(ggml_tensor_qtype["asym_int5"])
@@ -387,8 +388,10 @@ class LowBitLinear(nn.Linear):
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8,
-                              "NF3, NF4 and FP8 quantization are currently not supported on CPU")
+            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8
+                              and self.qtype != FP4,
+                              "NF3, NF4, FP4 and FP8 quantization are currently not"
+                              " supported on CPU")
             if IS_SERVER and (not IS_SPR) and \
                     self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                 x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 3f34c45f..e498987f 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -60,9 +60,10 @@ class _BaseAutoModelClass:
         :param load_in_4bit: boolean value, True means load linear's weight to symmetric int 4.
                              Default to be False.
         :param load_in_low_bit: str value, options are sym_int4, asym_int4, sym_int5, asym_int5
-                                , sym_int8, nf3, nf4 or fp16. sym_int4 means symmetric int 4,
-                                asym_int4 means asymmetric int 4, nf4 means 4-bit NormalFloat, etc.
-                                Relevant low bit optimizations will be applied to the model.
+                                , sym_int8, nf3, nf4, fp4, fp8 or fp16. sym_int4 means symmetric
+                                int 4, asym_int4 means asymmetric int 4, nf4 means 4-bit
+                                NormalFloat, etc. Relevant low bit optimizations will be applied
+                                to the model.
         :param optimize_model: boolean value, Whether to further optimize the low_bit llm model.
                                Default to be True.
         :param modules_to_not_convert: list of str value, modules (nn.Module) that are skipped when
@@ -106,8 +107,8 @@ class _BaseAutoModelClass:
         from .convert import ggml_convert_low_bit
         invalidInputError(q_k in ggml_tensor_qtype,
                           f"Unknown load_in_low_bit value: {q_k}, expected:"
-                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4 "
-                          "or fp16.")
+                          f" sym_int4, asym_int4, sym_int5, asym_int5, sym_int8, nf3, nf4, "
+                          "fp4, fp8 or fp16.")
         qtype = ggml_tensor_qtype[q_k]
         # In case it needs a second try,
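
Note (reviewer sketch, not part of the patch): a minimal way to exercise the
new "fp4" option end to end. The checkpoint id and prompt are placeholders,
and it assumes a bigdl-llm XPU install with intel_extension_for_pytorch;
per the LowBitLinear change above, FP4 is rejected on the CPU path, so the
model must be moved to an Intel GPU ("xpu") device.

    import intel_extension_for_pytorch as ipex  # noqa: F401 -- registers the "xpu" device
    from transformers import LlamaTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    # load_in_low_bit="fp4" selects the new qtype added by this patch (ggml id 16).
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",   # placeholder; any HF causal-LM checkpoint
        load_in_low_bit="fp4",
        trust_remote_code=True,
    ).to("xpu")                       # FP4 must run on XPU; CPU raises invalidInputError

    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))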