From 65121c79976e295f0a0bf9a00fd9e9073cf94d5e Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 29 Nov 2023 14:40:37 +0800
Subject: [PATCH] support loading q4_1/q5_0/q5_1/q8_0 gguf model (#9546)

---
 .../src/bigdl/llm/transformers/gguf/api.py    |  6 +-
 .../src/bigdl/llm/transformers/gguf/gguf.py   | 66 +++++++++++++++++--
 2 files changed, 66 insertions(+), 6 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/gguf/api.py b/python/llm/src/bigdl/llm/transformers/gguf/api.py
index d2e3f389..0ee72d5a 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/api.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/api.py
@@ -19,7 +19,11 @@ from bigdl.llm.utils.common import invalidInputError
 
 
 qtype_map = {
-    2: "sym_int4" # q4_0
+    2: "sym_int4",  # q4_0
+    3: "asym_int4",  # q4_1
+    7: "sym_int8",  # q8_0
+    8: "sym_int5",  # q5_0
+    9: "asym_int5",  # q5_1
 }
 
 
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/gguf.py b/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
index 06b52276..6fe34a5c 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
@@ -121,7 +121,7 @@ class GGUFHeader:
         invalidInputError(magic == "GGUF", "not a valid gguf file")
 
         version, n_tensors, n_kv = struct.unpack("<IQQ", data[4:])
@@ ... @@
+    def convert_q4_1_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        # see https://github.com/ggerganov/llama.cpp/blob
+        # /b38a16dfcff88d547f78f52d1bea31b84a05aff7/ggml-quants.c
+
+        block_size = self.block_size[3]
+        tensor = tensor.reshape((-1, block_size))
+        scales, base, data = tensor[:, :2], tensor[:, 2:4], tensor[:, 4:]
+        scales = scales.view(torch.half)
+        base = base.view(torch.half)
+        data = torch.cat([data & 0xF, data >> 4], dim=-1)
+        result = (data * scales + base).reshape(dims)
+        return result
+
+    def convert_q5_0_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        # see https://github.com/ggerganov/llama.cpp/blob
+        # /b38a16dfcff88d547f78f52d1bea31b84a05aff7/ggml-quants.c#L1115
+
+        block_size = self.block_size[6]
+        tensor = tensor.reshape((-1, block_size))
+        scales, hdata, ldata = tensor[:, :2], tensor[:, 2:6], tensor[:, 6:]
+        scales = scales.view(torch.half)
+        # hdata = hdata.view(torch.int)
+        hdata = hdata.clone().view(torch.int)  # clone hdata to fix memory address alignment
+        shift = torch.arange(0, 32, 1)
+        hdata = (((hdata.expand(-1, 32) >> shift) << 4) & 0x10).byte()
+        ldata = torch.cat([ldata & 0xF, ldata >> 4], dim=-1)
+        data = (hdata | ldata).view(torch.int8) - 16
+        result = (data * scales).reshape(dims)
+        return result
+
+    def convert_q5_1_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        # https://github.com/ggerganov/llama.cpp/blob
+        # /b38a16dfcff88d547f78f52d1bea31b84a05aff7/ggml-quants.c#L1141
+
+        block_size = self.block_size[7]
+        tensor = tensor.reshape((-1, block_size))
+        scales, base, hdata, ldata = tensor[:, :2], tensor[:, 2:4], tensor[:, 4:8], tensor[:, 8:]
+        scales = scales.view(torch.half)
+        base = base.view(torch.half)
+        hdata = hdata.view(torch.int)
+        shift = torch.arange(0, 32, 1)
+        hdata = (((hdata.expand(-1, 32) >> shift) << 4) & 0x10).byte()
+        ldata = torch.cat([ldata & 0xF, ldata >> 4], dim=-1)
+        data = hdata | ldata
+        result = (data * scales + base).reshape(dims)
+        return result
+
+    def convert_q8_0_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        # https://github.com/ggerganov/llama.cpp/blob
+        # /b38a16dfcff88d547f78f52d1bea31b84a05aff7/ggml-quants.c#L1168
+
+        block_size = self.block_size[8]
+        tensor = tensor.reshape((-1, block_size))
+        scales, data = tensor[:, :2], tensor[:, 2:]
+        scales = scales.view(torch.half)
+        data = data.view(torch.int8)
+        result = (data * scales).reshape(dims)
+        return result
 
     def convert_q6_k_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
         # see https://github.com/ggerganov/llama.cpp/blob
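
The integer keys in qtype_map are llama.cpp file-type ids (the LLAMA_FTYPE enum, stored in gguf metadata as general.file_type), not tensor-type ids; that is why q8_0 maps from 7 while q5_0 and q5_1 map from 8 and 9. A minimal lookup sketch, with the map copied from the hunk above and a made-up input value:

    qtype_map = {
        2: "sym_int4",   # q4_0
        3: "asym_int4",  # q4_1
        7: "sym_int8",   # q8_0
        8: "sym_int5",   # q5_0
        9: "asym_int5",  # q5_1
    }

    file_type = 9                     # value of the general.file_type metadata key
    qtype = qtype_map.get(file_type)  # -> "asym_int5" for a q5_1 model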
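
The struct format in GGUFHeader corresponds to the fixed GGUF v2 header: a 4-byte magic followed by a little-endian uint32 version, a uint64 tensor count, and a uint64 metadata-kv count. A self-contained sketch of that layout (read_gguf_header is a hypothetical helper, not part of the patch):

    import struct

    def read_gguf_header(path: str):
        # 4-byte magic, then "<IQQ": uint32 version, uint64 n_tensors, uint64 n_kv
        with open(path, "rb") as f:
            data = f.read(4 + 4 + 8 + 8)
        if data[0:4].decode("utf-8") != "GGUF":
            raise ValueError("not a valid gguf file")
        version, n_tensors, n_kv = struct.unpack("<IQQ", data[4:])
        return version, n_tensors, n_kv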
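
Each new converter indexes self.block_size with a GGML tensor-type id (3, 6, 7, 8), so the patch must also extend the loader's per-type tables and register the converters under the same ids; those hunks are not reproduced above. The byte counts implied by the slicing in the converters, as a sketch (dict name taken from the visible code, exact original formatting assumed):

    # bytes per 32-element quantization block, keyed by GGML tensor-type id
    block_size = {
        2: 18,  # q4_0: fp16 scale + 16 bytes of packed 4-bit values
        3: 20,  # q4_1: fp16 scale + fp16 base + 16 bytes of packed 4-bit values
        6: 22,  # q5_0: fp16 scale + 4 bytes of high bits + 16 bytes of low nibbles
        7: 24,  # q5_1: fp16 scale + fp16 base + 4 high-bit bytes + 16 low-nibble bytes
        8: 34,  # q8_0: fp16 scale + 32 signed 8-bit values
    }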
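
In the q4_1 format each block stores an fp16 scale and an fp16 base (the block minimum), so an element dequantizes as nibble * scale + base; the low nibbles of the 16 payload bytes are elements 0-15 and the high nibbles are elements 16-31, matching the torch.cat order above. A hypothetical one-block example:

    import torch

    scales = torch.tensor([[0.5]], dtype=torch.half)
    base = torch.tensor([[-1.0]], dtype=torch.half)
    nibbles = torch.arange(16, dtype=torch.uint8)
    packed = (nibbles | (nibbles << 4)).unsqueeze(0)       # 16 bytes of packed nibbles
    data = torch.cat([packed & 0xF, packed >> 4], dim=-1)  # 32 elements per block
    print((data * scales + base)[0, :4])
    # tensor([-1.0000, -0.5000,  0.0000,  0.5000], dtype=torch.float16)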
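
q5_0 and q5_1 store the fifth (high) bit of all 32 block elements packed into a single 32-bit word. Broadcasting that word against shift = 0..31 pulls out bit j for element j, and the "<< 4) & 0x10" step parks it at bit position 4, ready to be OR-ed onto the 4-bit low value. A small demonstration with a made-up word:

    import torch

    qh = torch.tensor([[0b1010]], dtype=torch.int32)  # elements 1 and 3 have their high bit set
    shift = torch.arange(0, 32, 1)
    high = (((qh.expand(-1, 32) >> shift) << 4) & 0x10).byte()
    print(high[0, :4])  # tensor([ 0, 16,  0, 16], dtype=torch.uint8)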
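
The clone() in convert_q5_0_tensor is needed because hdata is a strided slice starting at byte offset 2 of every 22-byte block, and torch.Tensor.view(dtype) only reinterprets storage whose offset and strides are expressible in units of the new 4-byte element; in convert_q5_1_tensor the high bits sit at offset 4 of a 24-byte block, so the direct view happens to be legal there. A minimal reproduction of the failure and the fix:

    import torch

    blocks = torch.zeros((4, 22), dtype=torch.uint8)  # four 22-byte q5_0 blocks
    hdata = blocks[:, 2:6]                            # high-bit bytes, offset 2 in each block
    try:
        hdata.view(torch.int)                 # fails: offset/strides not divisible by 4
    except RuntimeError as e:
        print(e)
    print(hdata.clone().view(torch.int))      # contiguous copy reinterprets cleanly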
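
q8_0 is the simplest of the new layouts: each 34-byte block is an fp16 scale followed by 32 signed int8 values, and dequantization is just data * scale. A quick round-trip check with made-up values:

    import torch

    scale = torch.tensor([0.5], dtype=torch.half)
    vals = torch.arange(-16, 16, dtype=torch.int8)
    block = torch.cat([scale.view(torch.uint8), vals.view(torch.uint8)])  # one 34-byte block
    s = block[:2].view(torch.half)
    d = block[2:].view(torch.int8)
    print((d * s)[:4])  # tensor([-8.0000, -7.5000, -7.0000, -6.5000], dtype=torch.float16)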