gguf memory optimization for mixtral (#9939)

This commit is contained in:
Shaojun Liu 2024-01-19 11:13:15 +08:00 committed by GitHub
parent 610b5226be
commit 967714bac8
3 changed files with 11 additions and 3 deletions

View file

@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
is_linear, linear_args = is_linear_module(module)
if is_linear and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
module.weight.data.device.type != 'meta' and
not isinstance(module, LowBitLinear)):
in_features, out_features, mp_group = linear_args
with init_empty_weights():
new_linear = None

View file

@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
low_bit='sym_int4'):
config = loader.config
@ -78,7 +79,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
else:
set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
tensor_loader = loader.tensor_loader
tensor_loader.load_while_process(process_mistral)

View file

@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
from ..gguf import GGUFFileLoader
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
low_bit='sym_int4'):
# mixtral enjoys a general architecture of llma
# e.g. it applies llama tokenizer
config = loader.config
@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
n_head = config['llama.attention.head_count']
n_head_kv = config['llama.attention.head_count_kv']
hidden_size = config['llama.embedding_length']
qtype = ggml_tensor_qtype[low_bit]
mixtral_config = MixtralConfig(
vocab_size=len(config['tokenizer.ggml.tokens']),
@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
"cpu",
tensor,
dtype=dtype)
model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
tensor_loader = loader.tensor_loader
tensor_loader.load_while_process(process_mixtral)