[LLM] support loading gguf model (#9544)
parent 32b37f3af7
commit a86c6e0b56
6 changed files with 512 additions and 0 deletions
python/llm/src/bigdl/llm/transformers/gguf/__init__.py (new file, 17 additions)
@@ -0,0 +1,17 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .api import load_gguf_model
python/llm/src/bigdl/llm/transformers/gguf/api.py (new file, 44 additions)
@@ -0,0 +1,44 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from bigdl.llm.utils.common import invalidInputError


qtype_map = {
    2: "sym_int4"    # q4_0
}


def load_gguf_model(fpath: str, dtype: torch.dtype = torch.float):
    from .gguf import GGUFFileLoader

    loader = GGUFFileLoader(fpath)
    model_family = loader.config["general.architecture"]
    qtype = loader.config["general.file_type"]

    invalidInputError(qtype in qtype_map, f"Unsupported gguf quantize type: {qtype}")
    low_bit = qtype_map.get(qtype, "sym_int4")

    with torch.no_grad():
        if model_family == "llama":
            from .models.llama import load_gguf_llama

            model = load_gguf_llama(loader, dtype)
        else:
            invalidInputError(False, f"Unsupported model family: {model_family}")

    return model, low_bit
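For reference, a minimal sketch of how this new API is intended to be called; the gguf file path below is a placeholder, and only q4_0 llama gguf files are handled by this commit:

import torch
from bigdl.llm.transformers.gguf import load_gguf_model

# "llama-2-7b-q4_0.gguf" is a hypothetical path; general.file_type must be 2 (q4_0)
model, low_bit = load_gguf_model("llama-2-7b-q4_0.gguf", dtype=torch.half)
print(low_bit)  # "sym_int4", ready to be passed on to optimize_model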
python/llm/src/bigdl/llm/transformers/gguf/gguf.py (new file, 318 additions)
@@ -0,0 +1,318 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# see https://github.com/ggerganov/llama.cpp/blob/master/convert.py
# and https://github.com/ggerganov/llama.cpp/blob/master/ggml-quants.c
# and https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp

import struct
import functools
import torch
import numpy

from io import BufferedReader
from tqdm import tqdm
from bigdl.llm.utils.common import invalidInputError

class GGUFReader:
    def __init__(self, f: BufferedReader):
        self.f = f
        self.funcs = {
            0: self.read_u8,
            1: self.read_i8,
            2: self.read_u16,
            3: self.read_i16,
            4: self.read_u32,
            5: self.read_i32,
            6: self.read_f32,
            7: self.read_bool,
            8: self.read_str,
            9: self.read_array,
            10: self.read_u64,
            11: self.read_i64,
            12: self.read_f64,
        }

    def read_value(self):
        value_type = self.read_i32()
        value = self.funcs[value_type]()
        return value

    def read_bool(self):
        data = self.f.read(1)
        return struct.unpack("<?", data)[0]

    def read_i8(self):
        data = self.f.read(1)
        return struct.unpack("<b", data)[0]

    def read_u8(self):
        data = self.f.read(1)
        return struct.unpack("<B", data)[0]

    def read_i16(self):
        data = self.f.read(2)
        return struct.unpack("<h", data)[0]

    def read_u16(self):
        data = self.f.read(2)
        return struct.unpack("<H", data)[0]

    def read_i32(self):
        data = self.f.read(4)
        return struct.unpack("<i", data)[0]

    def read_u32(self):
        data = self.f.read(4)
        return struct.unpack("<I", data)[0]

    def read_i64(self):
        data = self.f.read(8)
        return struct.unpack("<q", data)[0]

    def read_u64(self):
        data = self.f.read(8)
        return struct.unpack("<Q", data)[0]

    def read_f32(self):
        data = self.f.read(4)
        return struct.unpack("<f", data)[0]

    def read_f64(self):
        data = self.f.read(8)
        return struct.unpack("<d", data)[0]

    def read_str(self):
        length = self.read_u64()
        data = self.f.read(length)
        return data.decode()

    def read_array(self):
        item_type = self.read_i32()
        item_num = self.read_u64()
        arr = [
            self.funcs[item_type]()
            for i in range(item_num)
        ]
        return arr
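A minimal sketch of the typed key/value dispatch above, driven from an in-memory stream; the encoded bytes are made up for illustration:

import io
import struct

# type id 4 (u32) followed by the value 123, little-endian, as read_value expects
reader = GGUFReader(io.BytesIO(struct.pack("<iI", 4, 123)))
print(reader.read_value())  # 123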
class GGUFHeader:
    size = 4 + 4 + 8 + 8

    def __init__(self, f: BufferedReader):
        data = f.read(GGUFHeader.size)

        magic = data[0:4].decode()
        invalidInputError(magic == "GGUF", "not a valid gguf file")

        version, n_tensors, n_kv = struct.unpack("<IQQ", data[4:])
        invalidInputError(version == 2, "only gguf v2 is supported")

        self.magic = magic
        self.version = version
        self.n_tensors = n_tensors
        self.n_kv = n_kv
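A small sanity check of the 24-byte header layout parsed above (4-byte magic, u32 version, u64 tensor count, u64 kv count), again on an in-memory stream with assumed values:

import io
import struct

header = GGUFHeader(io.BytesIO(b"GGUF" + struct.pack("<IQQ", 2, 3, 7)))
print(header.version, header.n_tensors, header.n_kv)  # 2 3 7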
class GGUFConfig:
    def __init__(self, f: BufferedReader, header: GGUFHeader):
        self.config = {}

        reader = GGUFReader(f)
        for i in range(header.n_kv):
            key = reader.read_str()
            value = reader.read_value()
            self.config[key] = value
class GGUFTensorInfos:
    def __init__(self, f: BufferedReader, header: GGUFHeader, config: GGUFConfig):
        self.infos = []

        reader = GGUFReader(f)
        for i in range(header.n_tensors):
            name = reader.read_str()
            ndims = reader.read_u32()
            dims = [
                reader.read_u64()
                for i in range(ndims)
            ]
            dims = list(reversed(dims))
            qtype = reader.read_i32()
            offset = reader.read_u64()
            self.infos.append((name, ndims, dims, qtype, offset))

        alignment = config.config.get("general.alignment", 32)
        base_offset = (f.tell() + alignment - 1) // alignment * alignment
        self.base_offset = base_offset
class GGUFTensorLoader:
    def __init__(self, fpath: str, tensor_infos: GGUFTensorInfos):
        self.block_ne = {
            0: 1,       # f32
            1: 1,       # f16
            2: 32,      # q4_0
            3: 32,      # q4_1
            6: 32,      # q5_0
            7: 32,      # q5_1
            8: 32,      # q8_0
            9: 32,      # q8_1
            10: 256,    # q2_k
            11: 256,    # q3_k
            12: 256,    # q4_k
            13: 256,    # q5_k
            14: 256,    # q6_k
            15: 256,    # q8_k
            16: 1,      # i8
            17: 1,      # i16
            18: 1,      # i32
        }

        self.block_size = {
            0: 4,       # f32
            1: 2,       # f16
            2: 18,      # q4_0
            3: 20,      # q4_1
            6: 22,      # q5_0
            7: 24,      # q5_1
            8: 34,      # q8_0
            9: 40,      # q8_1
            10: 0,      # q2_k
            11: 0,      # q3_k
            12: 0,      # q4_k
            13: 0,      # q5_k
            14: 210,    # q6_k
            15: 0,      # q8_k
            16: 1,      # i8
            17: 2,      # i16
            18: 4,      # i32
        }

        self.convert_funcs = {
            0: self.convert_f32_tensor,         # f32
            1: self.convert_f16_tensor,         # f16
            2: self.convert_q4_0_tensor,        # q4_0
            3: self.convert_q4_1_tensor,        # q4_1
            6: self.convert_unknown_tensor,     # q5_0
            7: self.convert_unknown_tensor,     # q5_1
            8: self.convert_unknown_tensor,     # q8_0
            9: self.convert_unknown_tensor,     # q8_1
            10: self.convert_unknown_tensor,    # q2_k
            11: self.convert_unknown_tensor,    # q3_k
            12: self.convert_unknown_tensor,    # q4_k
            13: self.convert_unknown_tensor,    # q5_k
            14: self.convert_q6_k_tensor,       # q6_k
            15: self.convert_unknown_tensor,    # q8_k
            16: self.convert_unknown_tensor,    # i8
            17: self.convert_unknown_tensor,    # i16
            18: self.convert_unknown_tensor,    # i32
        }

        self.fpath = fpath
        self.infos = tensor_infos.infos
        self.base_offset = tensor_infos.base_offset
    def __iter__(self):
        with open(self.fpath, 'rb') as f:
            for name, ndims, dims, qtype, offset in tqdm(self.infos, desc="Loading gguf tensors"):
                total_ne = functools.reduce(lambda x, y: x * y, dims)
                invalidInputError(total_ne % self.block_ne[qtype] == 0,
                                  f"wrong elements num: {dims}")

                size = total_ne // self.block_ne[qtype] * self.block_size[qtype]
                invalidInputError(size != 0, f"unsupported quantize type: {qtype}")

                offset += self.base_offset
                f.seek(offset)
                data = f.read(size)
                arr = numpy.frombuffer(data, dtype=numpy.uint8).copy()
                tensor = torch.from_numpy(arr)
                tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
                yield name, tensor
    def convert_f32_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        return tensor.view(torch.float)

    def convert_f16_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        return tensor.view(torch.half)

    def convert_q4_0_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        # see https://github.com/ggerganov/llama.cpp/blob
        # /8e672efe632bb6a7333964a255c4b96f018b9a65/ggml-quants.c#L1074

        block_size = self.block_size[2]
        tensor = tensor.reshape((-1, block_size))
        scales, data = tensor[:, :2], tensor[:, 2:]
        scales = scales.view(torch.half)
        data = torch.cat([data & 0xF, data >> 4], dim=-1).view(torch.int8) - 8
        result = (data * scales).reshape(dims)
        return result

    def convert_q4_1_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        invalidInputError(False, "q4_1 conversion is not implemented")

    def convert_q6_k_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        # see https://github.com/ggerganov/llama.cpp/blob
        # /8e672efe632bb6a7333964a255c4b96f018b9a65/ggml-quants.c#L2263

        block_size = self.block_size[14]
        tensor = tensor.reshape((-1, block_size))

        ql, qh, scales, d = (tensor[:, :128], tensor[:, 128:192],
                             tensor[:, 192:208], tensor[:, 208:])
        data_0 = (ql[:, 00:32] & 0xF) | ((qh[:, :32] & 0B00000011) << 4)
        data_1 = (ql[:, 32:64] & 0xF) | ((qh[:, :32] & 0B00001100) << 2)
        data_2 = (ql[:, 00:32] >> 4) | ((qh[:, :32] & 0B00110000) >> 0)
        data_3 = (ql[:, 32:64] >> 4) | ((qh[:, :32] & 0B11000000) >> 2)
        data_4 = (ql[:, 64:96] & 0xF) | ((qh[:, 32:64] & 0B00000011) << 4)
        data_5 = (ql[:, 96:128] & 0xF) | ((qh[:, 32:64] & 0B00001100) << 2)
        data_6 = (ql[:, 64:96] >> 4) | ((qh[:, 32:64] & 0B00110000) >> 0)
        data_7 = (ql[:, 96:128] >> 4) | ((qh[:, 32:64] & 0B11000000) >> 2)
        data = torch.cat([data_0, data_1, data_2, data_3, data_4, data_5, data_6, data_7],
                         dim=-1).view(torch.int8) - 32
        result = data * d.view(torch.half)

        result = result.reshape((-1, 16, 16)) * scales.view(torch.int8).reshape((-1, 16, 1))
        result = result.reshape(dims)
        return result

    def convert_unknown_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        invalidInputError(False, "Unsupported qtype")
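As a worked example of the tables above: a 4096 x 4096 q4_0 tensor has 16,777,216 elements, i.e. 524,288 blocks of 32 elements at 18 bytes each, so __iter__ reads 9,437,184 bytes for it. A synthetic single-block check of the q4_0 layout handled by convert_q4_0_tensor: an 18-byte block is a float16 scale followed by 32 packed 4-bit quants, and each value dequantizes to scale * (quant - 8). The block contents here are made up for illustration:

import numpy
import torch

d = numpy.array([0.5], dtype=numpy.float16).tobytes()   # scale = 0.5
qs = bytes([9 | (9 << 4)] * 16)                          # all 32 quants equal to 9
block = torch.from_numpy(numpy.frombuffer(d + qs, dtype=numpy.uint8).copy())

tensor = block.reshape((-1, 18))
scales, data = tensor[:, :2].view(torch.half), tensor[:, 2:]
values = torch.cat([data & 0xF, data >> 4], dim=-1).view(torch.int8) - 8
print((values * scales).flatten())                       # 32 values, all equal to 0.5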
class GGUFFileLoader:
    def __init__(self, fpath: str):
        with open(fpath, 'rb') as f:
            header = GGUFHeader(f)
            config = GGUFConfig(f, header)
            tensor_infos = GGUFTensorInfos(f, header, config)
            tensor_loader = GGUFTensorLoader(fpath, tensor_infos)

        self.header = header
        self.config = config.config
        self.tensor_loader = tensor_loader

    def tensors(self, dtype: torch.dtype = torch.float):
        return {
            name: value.to(dtype=dtype)
            for name, value in self.tensor_loader
        }

    def tensors_iter(self):
        return self.tensor_loader
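Putting the pieces together, a minimal sketch of how GGUFFileLoader is used; the file path is a placeholder:

import torch

loader = GGUFFileLoader("llama-2-7b-q4_0.gguf")        # hypothetical q4_0 llama gguf file
print(loader.config["general.architecture"])            # e.g. "llama"
print(loader.config["general.file_type"])               # e.g. 2 (q4_0)
state = loader.tensors(dtype=torch.half)                 # dict of dequantized fp16 tensors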
(new file, 15 additions)
@@ -0,0 +1,15 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
python/llm/src/bigdl/llm/transformers/gguf/models/llama.py (new file, 98 additions)
@@ -0,0 +1,98 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import LlamaConfig, LlamaForCausalLM

from ..gguf import GGUFFileLoader

def load_gguf_llama(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
    config = loader.config

    llama_config = LlamaConfig(
        vocab_size=len(config['tokenizer.ggml.tokens']),
        hidden_size=config['llama.embedding_length'],
        intermediate_size=config['llama.feed_forward_length'],
        num_hidden_layers=config['llama.block_count'],
        num_attention_heads=config['llama.attention.head_count'],
        num_key_value_heads=config['llama.attention.head_count_kv'],
        hidden_act="silu",
        max_position_embeddings=config['llama.context_length'],
        rms_norm_eps=config['llama.attention.layer_norm_rms_epsilon'],
        use_cache=True,
        pad_token_id=None,
        bos_token_id=config['tokenizer.ggml.bos_token_id'],
        eos_token_id=config['tokenizer.ggml.eos_token_id'],
        pretraining_tp=1,
    )

    ckpt = loader.tensors(dtype)
    n_head = config['llama.attention.head_count']
    n_head_kv = config['llama.attention.head_count_kv']
    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)

    state_dict = {}
    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
    state_dict['lm_head.weight'] = ckpt['output.weight']
    for i in range(config['llama.block_count']):
        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
            ckpt[f'blk.{i}.attn_q.weight']
        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
            ckpt[f'blk.{i}.attn_k.weight']
        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
            ckpt[f'blk.{i}.attn_v.weight']
        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
            ckpt[f'blk.{i}.attn_output.weight']
        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
            ckpt[f'blk.{i}.ffn_gate.weight']
        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
            ckpt[f'blk.{i}.ffn_up.weight']
        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
            ckpt[f'blk.{i}.ffn_down.weight']
        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
            ckpt[f'blk.{i}.attn_norm.weight']
        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
            ckpt[f'blk.{i}.ffn_norm.weight']

    with init_empty_weights():
        model = LlamaForCausalLM(llama_config)

    for name, weight in state_dict.items():
        set_module_tensor_to_device(model, name, "cpu", weight)

    model = model.cpu()

    return model
def restore_llama_weight(ckpt: dict, n_head: int, n_head_kv: int):
    # see https://github.com/ggerganov/llama.cpp/blob
    # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
    for name, weight in ckpt.items():
        head, hd_size = weight.shape[0], weight.shape[1:]
        if name.endswith("attn_q.weight"):
            ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
                          .swapaxes(1, 2)
                          .reshape(weight.shape))
        elif name.endswith("attn_k.weight"):
            ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
                          .swapaxes(1, 2)
                          .reshape(weight.shape))
    return ckpt
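The q/k reshuffling above undoes the head permutation that llama.cpp's convert.py applies when it writes attn_q/attn_k weights. A small round-trip sketch; the permute helper below mimics convert.py and is included only for illustration:

import torch

def permute(w, n_head):
    # llama.cpp convert.py-style permutation of the two rotary halves per head
    return (w.reshape(n_head, 2, w.shape[0] // n_head // 2, *w.shape[1:])
             .swapaxes(1, 2)
             .reshape(w.shape))

n_head, hidden = 4, 32
w = torch.randn(hidden, hidden)
ckpt = {"blk.0.attn_q.weight": permute(w, n_head)}
restored = restore_llama_weight(ckpt, n_head, n_head)
assert torch.equal(restored["blk.0.attn_q.weight"], w)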
@@ -44,6 +44,7 @@ from .utils import extract_local_archive_file, \
    get_local_shard_files
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.utils.common import invalidInputError
from bigdl.llm.transformers.gguf.api import load_gguf_model
import torch
import warnings
import copy
@@ -190,6 +191,25 @@ class _BaseAutoModelClass:
        return model

    @staticmethod
    def from_gguf(fpath: str, optimize_model: bool = True, cpu_embedding: bool = False):
        """
        Load a gguf model and convert it to a bigdl-llm model.

        :param fpath: Path to gguf model file
        :param optimize_model: Whether to further optimize the llm model, defaults to True
        :param cpu_embedding: Whether to replace the Embedding layer, may need to set it
            to `True` when running BigDL-LLM on GPU on Windows, defaults to False

        :return: An optimized bigdl-llm model
        """
        from bigdl.llm.optimize import optimize_model as optimize_model_fn

        model, low_bit = load_gguf_model(fpath, dtype=torch.half)
        model = optimize_model_fn(model, low_bit=low_bit, optimize_llm=optimize_model,
                                  cpu_embedding=cpu_embedding)
        return model

    @classmethod
    def load_convert(cls, q_k, optimize_model, *args, **kwargs):
        from .convert import ggml_convert_low_bit
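End to end, the new entry point is meant to be called roughly as follows; the file path is a placeholder, and this assumes AutoModelForCausalLM in bigdl.llm.transformers derives from _BaseAutoModelClass:

from bigdl.llm.transformers import AutoModelForCausalLM

# "llama-2-7b-q4_0.gguf" is a hypothetical path; only q4_0 llama gguf files are supported here
model = AutoModelForCausalLM.from_gguf("llama-2-7b-q4_0.gguf", optimize_model=True)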