From a86c6e0b564b88138b8f021360c8f096f622e5ca Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Tue, 28 Nov 2023 15:51:15 +0800 Subject: [PATCH] [LLM] support loading gguf model (#9544) --- .../bigdl/llm/transformers/gguf/__init__.py | 17 + .../src/bigdl/llm/transformers/gguf/api.py | 44 +++ .../src/bigdl/llm/transformers/gguf/gguf.py | 318 ++++++++++++++++++ .../llm/transformers/gguf/models/__init__.py | 15 + .../llm/transformers/gguf/models/llama.py | 98 ++++++ .../llm/src/bigdl/llm/transformers/model.py | 20 ++ 6 files changed, 512 insertions(+) create mode 100644 python/llm/src/bigdl/llm/transformers/gguf/__init__.py create mode 100644 python/llm/src/bigdl/llm/transformers/gguf/api.py create mode 100644 python/llm/src/bigdl/llm/transformers/gguf/gguf.py create mode 100644 python/llm/src/bigdl/llm/transformers/gguf/models/__init__.py create mode 100644 python/llm/src/bigdl/llm/transformers/gguf/models/llama.py diff --git a/python/llm/src/bigdl/llm/transformers/gguf/__init__.py b/python/llm/src/bigdl/llm/transformers/gguf/__init__.py new file mode 100644 index 00000000..ab175ed9 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/gguf/__init__.py @@ -0,0 +1,17 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .api import load_gguf_model diff --git a/python/llm/src/bigdl/llm/transformers/gguf/api.py b/python/llm/src/bigdl/llm/transformers/gguf/api.py new file mode 100644 index 00000000..d2e3f389 --- /dev/null +++ b/python/llm/src/bigdl/llm/transformers/gguf/api.py @@ -0,0 +1,44 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import torch
+from bigdl.llm.utils.common import invalidInputError
+
+
+qtype_map = {
+    2: "sym_int4"    # q4_0
+}
+
+
+def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
+    from .gguf import GGUFFileLoader
+
+    loader = GGUFFileLoader(fpath)
+    model_family = loader.config["general.architecture"]
+    qtype = loader.config["general.file_type"]
+
+    invalidInputError(qtype in qtype_map, f"Unsupported gguf quantize type: {qtype}")
+    low_bit = qtype_map[qtype]
+
+    with torch.no_grad():
+        if model_family == "llama":
+            from .models.llama import load_gguf_llama
+
+            model = load_gguf_llama(loader, dtype)
+        else:
+            invalidInputError(False, f"Unsupported model family: {model_family}")
+
+    return model, low_bit
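api.py is the single entry point that maps GGUF metadata to a bigdl-llm low-bit tag. A minimal usage sketch (the checkpoint path is hypothetical):

    import torch
    from bigdl.llm.transformers.gguf import load_gguf_model

    # returns the restored HF-format model plus the low-bit tag
    # ("sym_int4" for a q4_0 file) that tells the optimizer how to
    # re-quantize the weights
    model, low_bit = load_gguf_model("/path/to/llama-2-7b.Q4_0.gguf", dtype=torch.half)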
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/gguf.py b/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
new file mode 100644
index 00000000..06b52276
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
@@ -0,0 +1,318 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# see https://github.com/ggerganov/llama.cpp/blob/master/convert.py
+# and https://github.com/ggerganov/llama.cpp/blob/master/ggml-quants.c
+# and https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp
+
+import struct
+import functools
+import torch
+import numpy
+
+from io import BufferedReader
+from tqdm import tqdm
+from bigdl.llm.utils.common import invalidInputError
+
+
+class GGUFReader:
+    def __init__(self, f: BufferedReader):
+        self.f = f
+        self.funcs = {
+            0: self.read_u8,
+            1: self.read_i8,
+            2: self.read_u16,
+            3: self.read_i16,
+            4: self.read_u32,
+            5: self.read_i32,
+            6: self.read_f32,
+            7: self.read_bool,
+            8: self.read_str,
+            9: self.read_array,
+            10: self.read_u64,
+            11: self.read_i64,
+            12: self.read_f64,
+        }
+
+    def read_value(self):
+        value_type = self.read_i32()
+        value = self.funcs[value_type]()
+        return value
+
+    def read_bool(self):
+        data = self.f.read(1)
+        return struct.unpack("<?", data)[0]
+
+    def convert_q4_0_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        # a q4_0 block stores a float16 scale followed by 16 bytes
+        # holding 32 4-bit quants; weight = (quant - 8) * scale
+        block_size = self.block_size[2]
+        tensor = tensor.reshape((-1, block_size))
+        scales, data = tensor[:, :2], tensor[:, 2:]
+        scales = scales.view(torch.half)
+        data = torch.cat([data & 0xF, data >> 4], dim=-1).view(torch.int8) - 8
+        result = (data * scales).reshape(dims)
+        return result
+
+    def convert_q4_1_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        invalidInputError(False, "q4_1 conversion is not implemented")
+
+    def convert_q6_k_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        # see https://github.com/ggerganov/llama.cpp/blob
+        # /8e672efe632bb6a7333964a255c4b96f018b9a65/ggml-quants.c#L2263
+
+        block_size = self.block_size[14]
+        tensor = tensor.reshape((-1, block_size))
+        ql, qh, scales, d = (tensor[:, :128], tensor[:, 128:192],
+                             tensor[:, 192:208], tensor[:, 208:])
+        data_0 = (ql[:, 00:32] & 0xF) | ((qh[:, :32] & 0B00000011) << 4)
+        data_1 = (ql[:, 32:64] & 0xF) | ((qh[:, :32] & 0B00001100) << 2)
+        data_2 = (ql[:, 00:32] >> 4) | ((qh[:, :32] & 0B00110000) >> 0)
+        data_3 = (ql[:, 32:64] >> 4) | ((qh[:, :32] & 0B11000000) >> 2)
+        data_4 = (ql[:, 64:96] & 0xF) | ((qh[:, 32:64] & 0B00000011) << 4)
+        data_5 = (ql[:, 96:128] & 0xF) | ((qh[:, 32:64] & 0B00001100) << 2)
+        data_6 = (ql[:, 64:96] >> 4) | ((qh[:, 32:64] & 0B00110000) >> 0)
+        data_7 = (ql[:, 96:128] >> 4) | ((qh[:, 32:64] & 0B11000000) >> 2)
+        data = torch.cat([data_0, data_1, data_2, data_3, data_4, data_5, data_6, data_7],
+                         dim=-1).view(torch.int8) - 32
+        result = data * d.view(torch.half)
+
+        result = result.reshape((-1, 16, 16)) * scales.view(torch.int8).reshape((-1, 16, 1))
+        result = result.reshape(dims)
+        return result
+
+    def convert_unknown_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
+        invalidInputError(False, "Unsupported qtype")
+
+
+class GGUFFileLoader:
+    def __init__(self, fpath: str):
+        with open(fpath, 'rb') as f:
+            header = GGUFHeader(f)
+            config = GGUFConfig(f, header)
+            tensor_infos = GGUFTensorInfos(f, header, config)
+            tensor_loader = GGUFTensorLoader(fpath, tensor_infos)
+
+        self.header = header
+        self.config = config.config
+        self.tensor_loader = tensor_loader
+
+    def tensors(self, dtype: torch.dtype = torch.float):
+        return {
+            name: value.to(dtype=dtype)
+            for name, value in self.tensor_loader
+        }
+
+    def tensors_iter(self):
+        return self.tensor_loader
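The q4_0 path is easiest to see on a single block. The following standalone sketch (the block contents are made up for illustration) mirrors the dequantization steps above on one synthetic 18-byte block:

    import struct

    import torch

    # one synthetic q4_0 block: fp16 scale 0.5, then 16 bytes of packed nibbles
    raw = struct.pack("<e", 0.5) + bytes(range(16))
    block = torch.frombuffer(bytearray(raw), dtype=torch.uint8).reshape(1, 18)

    scales = block[:, :2].view(torch.half)    # per-block fp16 scale
    data = block[:, 2:]                       # 16 bytes -> 32 4-bit quants
    data = torch.cat([data & 0xF, data >> 4], dim=-1).view(torch.int8) - 8
    weights = data * scales                   # dequantized, shape (1, 32)
    print(weights)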
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/__init__.py b/python/llm/src/bigdl/llm/transformers/gguf/models/__init__.py
new file mode 100644
index 00000000..2151a805
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py b/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py
new file mode 100644
index 00000000..00e9574e
--- /dev/null
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/llama.py
@@ -0,0 +1,98 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from ..gguf import GGUFFileLoader
+
+
+def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+    config = loader.config
+
+    llama_config = LlamaConfig(
+        vocab_size=len(config['tokenizer.ggml.tokens']),
+        hidden_size=config['llama.embedding_length'],
+        intermediate_size=config['llama.feed_forward_length'],
+        num_hidden_layers=config['llama.block_count'],
+        num_attention_heads=config['llama.attention.head_count'],
+        num_key_value_heads=config['llama.attention.head_count_kv'],
+        hidden_act="silu",
+        max_position_embeddings=config['llama.context_length'],
+        rms_norm_eps=config['llama.attention.layer_norm_rms_epsilon'],
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=config['tokenizer.ggml.bos_token_id'],
+        eos_token_id=config['tokenizer.ggml.eos_token_id'],
+        pretraining_tp=1,
+    )
+
+    ckpt = loader.tensors(dtype)
+    n_head = config['llama.attention.head_count']
+    n_head_kv = config['llama.attention.head_count_kv']
+    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
+
+    state_dict = {}
+    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
+    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
+    state_dict['lm_head.weight'] = ckpt['output.weight']
+    for i in range(config['llama.block_count']):
+        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
+            ckpt[f'blk.{i}.attn_q.weight']
+        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
+            ckpt[f'blk.{i}.attn_k.weight']
+        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
+            ckpt[f'blk.{i}.attn_v.weight']
+        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
+            ckpt[f'blk.{i}.attn_output.weight']
+        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
+            ckpt[f'blk.{i}.ffn_gate.weight']
+        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
+            ckpt[f'blk.{i}.ffn_up.weight']
+        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
+            ckpt[f'blk.{i}.ffn_down.weight']
+        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
+            ckpt[f'blk.{i}.attn_norm.weight']
+        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
+            ckpt[f'blk.{i}.ffn_norm.weight']
+
+    with init_empty_weights():
+        model = LlamaForCausalLM(llama_config)
+
+    for name, weight in state_dict.items():
+        set_module_tensor_to_device(model, name, "cpu", weight)
+
+    model = model.cpu()
+
+    return model
+
+
+def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
+    # see https://github.com/ggerganov/llama.cpp/blob
+    # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
+    for name, weight in ckpt.items():
+        head, hd_size = weight.shape[0], weight.shape[1:]
+        # convert q/k weights from llama.cpp's interleaved rotary layout back
+        # to the half-split layout that transformers' LlamaAttention expects
+        if name.endswith("attn_q.weight"):
+            ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
+                          .swapaxes(1, 2)
+                          .reshape(weight.shape))
+        elif name.endswith("attn_k.weight"):
+            ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
+                          .swapaxes(1, 2)
+                          .reshape(weight.shape))
+    return ckpt
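The effect of restore_llama_weight on row order is easiest to check on a toy tensor. A minimal sketch (shapes and values are invented; one head with head size 4):

    import torch

    n_head, rows = 1, 4                       # head size 4 -> 4 weight rows
    w = torch.arange(8, dtype=torch.float).reshape(rows, 2)
    restored = (w.reshape(n_head, rows // n_head // 2, 2, 2)
                 .swapaxes(1, 2)
                 .reshape(w.shape))
    # rows reorder from (0, 1, 2, 3) to (0, 2, 1, 3): interleaved rotary
    # pairs become the half-split layout transformers' rotary code expects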
diff --git a/python/llm/src/bigdl/llm/transformers/model.py b/python/llm/src/bigdl/llm/transformers/model.py
index 5f0bdfb2..f231a13c 100644
--- a/python/llm/src/bigdl/llm/transformers/model.py
+++ b/python/llm/src/bigdl/llm/transformers/model.py
@@ -44,6 +44,7 @@ from .utils import extract_local_archive_file, \
     get_local_shard_files
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.gguf.api import load_gguf_model
 import torch
 import warnings
 import copy
@@ -190,6 +191,25 @@ class _BaseAutoModelClass:
 
         return model
 
+    @staticmethod
+    def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False):
+        """
+        Load a GGUF model and convert it to a bigdl-llm model.
+
+        :param fpath: Path to the GGUF model file.
+        :param optimize_model: Whether to further optimize the LLM model; defaults to True.
+        :param cpu_embedding: Whether to replace the Embedding layer; may need to be set
+               to `True` when running BigDL-LLM on GPU on Windows. Defaults to False.
+
+        :return: An optimized bigdl-llm model.
+        """
+        from bigdl.llm.optimize import optimize_model as optimize_model_fn
+
+        model, low_bit = load_gguf_model(fpath, dtype=torch.half)
+        model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model,
+                                  cpu_embedding=cpu_embedding)
+        return model
+
     @classmethod
     def load_convert(cls, q_k, optimize_model, *args, **kwargs):
         from .convert import ggml_convert_low_bit
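With this hook in place, GGUF loading becomes a one-liner on the public auto-model classes built on _BaseAutoModelClass. A rough usage sketch (the checkpoint path is hypothetical; note that this patch does not read the tokenizer from the GGUF file, so it still has to come from the original Hugging Face checkpoint):

    from bigdl.llm.transformers import AutoModelForCausalLM

    # loads the q4_0 weights and re-quantizes them to sym_int4 in one step
    model = AutoModelForCausalLM.from_gguf("/path/to/llama-2-7b.Q4_0.gguf")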