[LLM] support loading gguf model (#9544)

This commit is contained in:
Yishuo Wang 2023-11-28 15:51:15 +08:00 committed by GitHub
parent 32b37f3af7
commit a86c6e0b56
6 changed files with 512 additions and 0 deletions

View file

@ -0,0 +1,17 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from .api import load_gguf_model

View file

@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from bigdl.llm.utils.common import invalidInputError
qtype_map = {
2: "sym_int4" # q4_0
}
def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
from .gguf import GGUFFileLoader
loader = GGUFFileLoader(fpath)
model_family = loader.config["general.architecture"]
qtype = loader.config["general.file_type"]
invalidInputError(qtype in qtype_map, f"Unsupported gguf quantize type: {qtype}")
low_bit = qtype_map.get(qtype, "sym_int4")
with torch.no_grad():
if model_family == "llama":
from .models.llama import load_gguf_llama
model = load_gguf_llama(loader, dtype)
else:
invalidInputError(False, f"Unsupported model family: {model_family}")
return model, low_bit

View file

@ -0,0 +1,318 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# see https://github.com/ggerganov/llama.cpp/blob/master/convert.py
# and https://github.com/ggerganov/llama.cpp/blob/master/ggml-quants.c
# and https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp
import struct
import functools
import torch
import numpy
from io import BufferedReader
from tqdm import tqdm
from bigdl.llm.utils.common import invalidInputError
class GGUFReader:
def __init__(self, f: BufferedReader):
self.f = f
self.funcs = {
0: self.read_u8,
1: self.read_i8,
2: self.read_u16,
3: self.read_i16,
4: self.read_u32,
5: self.read_i32,
6: self.read_f32,
7: self.read_bool,
8: self.read_str,
9: self.read_array,
10: self.read_u64,
11: self.read_i64,
12: self.read_f64,
}
def read_value(self):
value_type = self.read_i32()
value = self.funcs[value_type]()
return value
def read_bool(self):
data = self.f.read(1)
return struct.unpack("<?", data)[0]
def read_i8(self):
data = self.f.read(1)
return struct.unpack("<b", data)[0]
def read_u8(self):
data = self.f.read(1)
return struct.unpack("<B", data)[0]
def read_i16(self):
data = self.f.read(2)
return struct.unpack("<h", data)[0]
def read_u16(self):
data = self.f.read(2)
return struct.unpack("<H", data)[0]
def read_i32(self):
data = self.f.read(4)
return struct.unpack("<i", data)[0]
def read_u32(self):
data = self.f.read(4)
return struct.unpack("<I", data)[0]
def read_i64(self):
data = self.f.read(8)
return struct.unpack("<q", data)[0]
def read_u64(self):
data = self.f.read(8)
return struct.unpack("<Q", data)[0]
def read_f32(self):
data = self.f.read(4)
return struct.unpack("<f", data)[0]
def read_f64(self):
data = self.f.read(8)
return struct.unpack("<d", data)[0]
def read_str(self):
length = self.read_u64()
data = self.f.read(length)
return data.decode()
def read_array(self):
item_type = self.read_i32()
item_num = self.read_u64()
arr = [
self.funcs[item_type]()
for i in range(item_num)
]
return arr
class GGUFHeader:
size = 4 + 4 + 8 + 8
def __init__(self, f: BufferedReader):
data = f.read(GGUFHeader.size)
magic = data[0:4].decode()
invalidInputError(magic == "GGUF", "not a valid gguf file")
version, n_tensors, n_kv = struct.unpack("<IQQ", data[4:])
invalidInputError(version == 2, "only gguf v2 is supported")
self.magic = magic
self.version = version
self.n_tensors = n_tensors
self.n_kv = n_kv
class GGUFConfig:
def __init__(self, f: BufferedReader, header: GGUFHeader):
self.config = {}
reader = GGUFReader(f)
for i in range(header.n_kv):
key = reader.read_str()
value = reader.read_value()
self.config[key] = value
class GGUFTensorInfos:
def __init__(self, f: BufferedReader, header: GGUFHeader, config: GGUFConfig):
self.infos = []
reader = GGUFReader(f)
for i in range(header.n_tensors):
name = reader.read_str()
ndims = reader.read_u32()
dims = [
reader.read_u64()
for i in range(ndims)
]
dims = list(reversed(dims))
qtype = reader.read_i32()
offset = reader.read_u64()
self.infos.append((name, ndims, dims, qtype, offset))
alignment = config.config.get("general.alignment", 32)
base_offset = (f.tell() + alignment - 1) // alignment * alignment
self.base_offset = base_offset
class GGUFTensorLoader:
def __init__(self, fpath: str, tensor_infos: GGUFTensorInfos):
self.block_ne = {
0: 1, # f32
1: 1, # f16
2: 32, # q4_0
3: 32, # q4_1
6: 32, # q5_0
7: 32, # q5_1
8: 32, # q8_0
9: 32, # q8_1
10: 256, # q2_k
11: 256, # q3_k
12: 256, # q4_k
13: 256, # q5_k
14: 256, # q6_k
15: 256, # q8_k
16: 1, # i8
17: 1, # i16
18: 1, # i32
}
self.block_size = {
0: 4, # f32
1: 2, # f16
2: 18, # q4_0
3: 20, # q4_1
6: 22, # q5_0
7: 24, # q5_1
8: 34, # q8_0
9: 40, # q8_1
10: 0, # q2_k
11: 0, # q3_k
12: 0, # q4_k
13: 0, # q5_k
14: 210, # q6_k
15: 0, # q8_k
16: 1, # i8
17: 2, # i16
18: 4, # i32
}
self.convert_funcs = {
0: self.convert_f32_tensor, # f32
1: self.convert_f16_tensor, # f16
2: self.convert_q4_0_tensor, # q4_0
3: self.convert_q4_1_tensor, # q4_1
6: self.convert_unknown_tensor, # q5_0
7: self.convert_unknown_tensor, # q5_1
8: self.convert_unknown_tensor, # q8_0
9: self.convert_unknown_tensor, # q8_1
10: self.convert_unknown_tensor, # q2_k
11: self.convert_unknown_tensor, # q3_k
12: self.convert_unknown_tensor, # q4_k
13: self.convert_unknown_tensor, # q5_k
14: self.convert_q6_k_tensor, # q6_k
15: self.convert_unknown_tensor, # q8_k
16: self.convert_unknown_tensor, # i8
17: self.convert_unknown_tensor, # i16
18: self.convert_unknown_tensor, # i32
}
self.fpath = fpath
self.infos = tensor_infos.infos
self.base_offset = tensor_infos.base_offset
def __iter__(self):
with open(self.fpath, 'rb') as f:
for name, ndims, dims, qtype, offset in tqdm(self.infos, desc="Loading gguf tensors"):
total_ne = functools.reduce(lambda x, y: x * y, dims)
invalidInputError(total_ne % self.block_ne[qtype] == 0,
f"wrong elements num: {dims}")
size = total_ne // self.block_ne[qtype] * self.block_size[qtype]
invalidInputError(size != 0, f"unsupported quantize type: {qtype}")
offset += self.base_offset
f.seek(offset)
data = f.read(size)
arr = numpy.frombuffer(data, dtype=numpy.uint8).copy()
tensor = torch.from_numpy(arr)
tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
yield name, tensor
def convert_f32_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
return tensor.view(torch.float)
def convert_f16_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
return tensor.view(torch.half)
def convert_q4_0_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
# see https://github.com/ggerganov/llama.cpp/blob
# /8e672efe632bb6a7333964a255c4b96f018b9a65/ggml-quants.c#L1074
block_size = self.block_size[2]
tensor = tensor.reshape((-1, block_size))
scales, data = tensor[:, :2], tensor[:, 2:]
scales = scales.view(torch.half)
data = torch.cat([data & 0xF, data >> 4], dim=-1).view(torch.int8) - 8
result = (data * scales).reshape(dims)
return result
def convert_q4_1_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
invalidInputError(False, "q4_1 conversion is not implemented")
def convert_q6_k_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
# see https://github.com/ggerganov/llama.cpp/blob
# /8e672efe632bb6a7333964a255c4b96f018b9a65/ggml-quants.c#L2263
block_size = self.block_size[14]
tensor = tensor.reshape((-1, block_size))
ql, qh, scales, d = (tensor[:, :128], tensor[:, 128:192],
tensor[:, 192:208], tensor[:, 208:])
data_0 = (ql[:, 00:32] & 0xF) | ((qh[:, :32] & 0B00000011) << 4)
data_1 = (ql[:, 32:64] & 0xF) | ((qh[:, :32] & 0B00001100) << 2)
data_2 = (ql[:, 00:32] >> 4) | ((qh[:, :32] & 0B00110000) >> 0)
data_3 = (ql[:, 32:64] >> 4) | ((qh[:, :32] & 0B11000000) >> 2)
data_4 = (ql[:, 64:96] & 0xF) | ((qh[:, 32:64] & 0B00000011) << 4)
data_5 = (ql[:, 96:128] & 0xF) | ((qh[:, 32:64] & 0B00001100) << 2)
data_6 = (ql[:, 64:96] >> 4) | ((qh[:, 32:64] & 0B00110000) >> 0)
data_7 = (ql[:, 96:128] >> 4) | ((qh[:, 32:64] & 0B11000000) >> 2)
data = torch.cat([data_0, data_1, data_2, data_3, data_4, data_5, data_6, data_7],
dim=-1).view(torch.int8) - 32
result = data * d.view(torch.half)
result = result.reshape((-1, 16, 16)) * scales.view(torch.int8).reshape((-1, 16, 1))
result = result.reshape(dims)
return result
def convert_unknown_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
invalidInputError(False, "Unsupported qtype")
class GGUFFileLoader:
def __init__(self, fpath: str):
with open(fpath, 'rb') as f:
header = GGUFHeader(f)
config = GGUFConfig(f, header)
tensor_infos = GGUFTensorInfos(f, header, config)
tensor_loader = GGUFTensorLoader(fpath, tensor_infos)
self.header = header
self.config = config.config
self.tensor_loader = tensor_loader
def tensors(self, dtype: torch.dtype = torch.float):
return {
name: value.to(dtype=dtype)
for name, value in self.tensor_loader
}
def tensors_iter(self):
return self.tensor_loader

View file

@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View file

@ -0,0 +1,98 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import LlamaConfig, LlamaForCausalLM
from ..gguf import GGUFFileLoader
def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
config = loader.config
llama_config = LlamaConfig(
vocab_size=len(config['tokenizer.ggml.tokens']),
hidden_size=config['llama.embedding_length'],
intermediate_size=config['llama.feed_forward_length'],
num_hidden_layers=config['llama.block_count'],
num_attention_heads=config['llama.attention.head_count'],
num_key_value_heads=config['llama.attention.head_count_kv'],
hidden_act="silu",
max_position_embeddings=config['llama.context_length'],
rms_norm_eps=config['llama.attention.layer_norm_rms_epsilon'],
use_cache=True,
pad_token_id=None,
bos_token_id=config['tokenizer.ggml.bos_token_id'],
eos_token_id=config['tokenizer.ggml.eos_token_id'],
pretraining_tp=1,
)
ckpt = loader.tensors(dtype)
n_head = config['llama.attention.head_count']
n_head_kv = config['llama.attention.head_count_kv']
ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
state_dict = {}
state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
state_dict['model.norm.weight'] = ckpt['output_norm.weight']
state_dict['lm_head.weight'] = ckpt['output.weight']
for i in range(config['llama.block_count']):
state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
ckpt[f'blk.{i}.attn_q.weight']
state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
ckpt[f'blk.{i}.attn_k.weight']
state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
ckpt[f'blk.{i}.attn_v.weight']
state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
ckpt[f'blk.{i}.attn_output.weight']
state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
ckpt[f'blk.{i}.ffn_gate.weight']
state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
ckpt[f'blk.{i}.ffn_up.weight']
state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
ckpt[f'blk.{i}.ffn_down.weight']
state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
ckpt[f'blk.{i}.attn_norm.weight']
state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
ckpt[f'blk.{i}.ffn_norm.weight']
with init_empty_weights():
model = LlamaForCausalLM(llama_config)
for name, weight in state_dict.items():
set_module_tensor_to_device(model, name, "cpu", weight)
model = model.cpu()
return model
def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
# see https://github.com/ggerganov/llama.cpp/blob
# /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
for name, weight in ckpt.items():
head, hd_size = weight.shape[0], weight.shape[1:]
if name.endswith("attn_q.weight"):
ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
.swapaxes(1, 2)
.reshape(weight.shape))
elif name.endswith("attn_k.weight"):
ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
.swapaxes(1, 2)
.reshape(weight.shape))
return ckpt

View file

@ -44,6 +44,7 @@ from .utils import extract_local_archive_file, \
get_local_shard_files
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.gguf.api import load_gguf_model
import torch
import warnings
import copy
@ -190,6 +191,25 @@ class _BaseAutoModelClass:
return model
@staticmethod
def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False):
"""
Load a gguf model and convert it to bigdl-llm model
:param fpath: Path to gguf model file
:param optimize_model: Whether to further optimize llm model, defaults to True
:param cpu_embedding: Whether to replace the Embedding layer, may need to set it
to `True` when running BigDL-LLM on GPU on Windows, defaults to False
:return: An optimized bigdl-llm model
"""
from bigdl.llm.optimize import optimize_model as optimize_model_fn
model, low_bit = load_gguf_model(fpath, dtype=torch.half)
model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model,
cpu_embedding=cpu_embedding)
return model
@classmethod
def load_convert(cls, q_k, optimize_model, *args, **kwargs):
from .convert import ggml_convert_low_bit