From 038330668892623053952b3074966269923d63d7 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 20 Oct 2023 17:15:07 +0800
Subject: [PATCH] Add arc fp8 support (#9232)

* add fp8 support

* add log

* fix style
---
 python/llm/src/bigdl/llm/ggml/quantize.py         |  3 ++-
 python/llm/src/bigdl/llm/transformers/convert.py  |  3 +++
 .../src/bigdl/llm/transformers/low_bit_linear.py  | 15 ++++++++++-----
 3 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/quantize.py b/python/llm/src/bigdl/llm/ggml/quantize.py
index 579ee913..c162c694 100644
--- a/python/llm/src/bigdl/llm/ggml/quantize.py
+++ b/python/llm/src/bigdl/llm/ggml/quantize.py
@@ -32,7 +32,8 @@ ggml_tensor_qtype = {"sym_int4": 2,   # q4_0 in ggml
                      "sym_int8": 8,   # q8_0 in ggml
                      "nf4": 10,
                      "nf3": 11,
-                     "fp16": 12}
+                     "fp16": 12,
+                     "fp8": 15}
 
 _llama_quantize_type = {"q4_0": 2,
                         "q4_1": 3,
diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 30d9d0e3..c1902d35 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -146,6 +146,9 @@ def _optimize_pre(model):
 def ggml_convert_low_bit(model, qtype, optimize_model=True,
                          convert_shape_only=False, device="cpu",
                          modules_to_not_convert=None):
+    logger.info(f"Converting the current model to "
+                f"{list(ggml_tensor_qtype.keys())[list(ggml_tensor_qtype.values()).index(qtype)]} "
+                f"format......")
     modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
 
     if optimize_model:
diff --git a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
index 8c7392cc..b1026ec7 100644
--- a/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
+++ b/python/llm/src/bigdl/llm/transformers/low_bit_linear.py
@@ -64,6 +64,7 @@ SYM_INT4 = ggml_tensor_qtype["sym_int4"]
 SYM_INT8 = ggml_tensor_qtype["sym_int8"]
 NF4 = ggml_tensor_qtype["nf4"]
 NF3 = ggml_tensor_qtype["nf3"]
+FP8 = ggml_tensor_qtype["fp8"]
 
 
 def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
@@ -87,9 +88,13 @@ def ggml_convert_qtype(tensor: torch.Tensor, qtype: int,
                              device=device)
 
     if not convert_shape_only and device != 'meta':
-        dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
-        hist = (ctypes.c_int64 * 16)()
-        ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
+        if qtype == FP8:
+            import linear_q4_0
+            linear_q4_0.cvt_fp32_e4m3_rne(tensor, dst_tensor, n, k)
+        else:
+            dst = ctypes.c_void_p(dst_tensor.data.data_ptr())
+            hist = (ctypes.c_int64 * 16)()
+            ggml.ggml_quantize_tensor(src, dst, qtype, n, k, hist)
     return dst_tensor
 
 
@@ -378,8 +383,8 @@ class LowBitLinear(nn.Linear):
         else:
             # CPU logic
             # todo may need to set a different number on different platforms
-            invalidInputError(self.qtype != NF3 and self.qtype != NF4,
-                              "NF3 and NF4 quantization are currently not supported on CPU")
+            invalidInputError(self.qtype != NF3 and self.qtype != NF4 and self.qtype != FP8,
+                              "NF3, NF4 and FP8 quantization are currently not supported on CPU")
             if IS_SERVER and (not IS_SPR) and \
                     self.qtype == SYM_INT4 and x_2d.shape[0] >= TORCH_LINEAR_THRESHOLD:
                 x0_fp32 = ggml_int4_convert_fp32(x0, self.weight_shape, self.weight_length)
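
Usage note (not part of the patch): with "fp8" registered in ggml_tensor_qtype, it should be selectable through the load_in_low_bit argument that this repo's AutoModelForCausalLM wrapper forwards to ggml_convert_low_bit. The following is a minimal sketch only, assuming an Intel Arc GPU exposed as the "xpu" device via intel_extension_for_pytorch; the model id and prompt are placeholders.

    # Hypothetical usage sketch for the new "fp8" qtype; not taken from the patch.
    import intel_extension_for_pytorch as ipex  # noqa: F401  registers the "xpu" device
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder model id

    # load_in_low_bit accepts the keys of ggml_tensor_qtype, which now include "fp8"
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit="fp8",
                                                 trust_remote_code=True)
    model = model.to("xpu")  # FP8 is rejected on the CPU path, so move to the Arc GPU

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    input_ids = tokenizer("What is FP8 quantization?",
                          return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

Because the LowBitLinear change above guards the CPU branch with invalidInputError, FP8 weights must live on the XPU device before generation; on CPU the call raises an error naming NF3, NF4 and FP8 as unsupported.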
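
Background on the conversion step: the patch routes FP8 weights to a native kernel, linear_q4_0.cvt_fp32_e4m3_rne, whose name indicates a float32-to-E4M3 cast with round-to-nearest-even. The sketch below only illustrates that numeric effect in plain PyTorch; it is an assumption that it approximates the kernel's behavior, and it requires a PyTorch build (>= 2.1) that exposes torch.float8_e4m3fn.

    # Illustrative approximation of fp32 -> fp8 (E4M3) conversion; not the repo's kernel.
    import torch

    def fp8_e4m3_roundtrip(weight: torch.Tensor) -> torch.Tensor:
        fp8 = weight.to(torch.float8_e4m3fn)  # 1 sign, 4 exponent, 3 mantissa bits
        return fp8.to(torch.float32)          # widen back to inspect the rounding error

    w = torch.randn(4, 8)
    print((w - fp8_e4m3_roundtrip(w)).abs().max())  # worst-case rounding error for this sample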