gguf memory optimization for mixtral (#9939)

Shaojun Liu 2024-01-19 11:13:15 +08:00 committed by GitHub
parent 610b5226be
commit 967714bac8
3 changed files with 11 additions and 3 deletions


@@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         is_linear, linear_args = is_linear_module(module)
         if is_linear and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+                    module.weight.data.device.type != 'meta' and
+                    not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
                 with init_empty_weights():
                     new_linear = None
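The tightened condition skips modules whose weights are still placeholder tensors on the 'meta' device and modules that are already LowBitLinear instances, which avoids touching unmaterialized weights or re-quantizing layers when the conversion helper is invoked once per module during GGUF loading. A minimal standalone sketch of that guard, with a dummy stand-in for BigDL's LowBitLinear class:

import torch.nn as nn

class LowBitLinear(nn.Linear):
    """Dummy stand-in for BigDL's LowBitLinear (the already-quantized layer type)."""
    pass

def should_convert(module, current_key_name, modules_to_not_convert):
    # Mirrors the tightened condition: convert only layers that are not excluded
    # by name, have real (materialized) weights, and were not converted already.
    full_name = ".".join(current_key_name)
    return (not any(key in full_name for key in modules_to_not_convert)
            and module.weight.data.device.type != 'meta'
            and not isinstance(module, LowBitLinear))

meta_linear = nn.Linear(16, 16, device='meta')   # placeholder weights only
real_linear = nn.Linear(16, 16)

print(should_convert(meta_linear, ["model", "layers", "0", "q_proj"], ["lm_head"]))  # False
print(should_convert(real_linear, ["model", "layers", "0", "q_proj"], ["lm_head"]))  # True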


@@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
                       low_bit='sym_int4'):
     config = loader.config


@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     # mixtral enjoys a general architecture of llma
     # e.g. it applies llama tokenizer
     config = loader.config
@@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
     hidden_size = config['llama.embedding_length']
+    qtype = ggml_tensor_qtype[low_bit]
 
     mixtral_config = MixtralConfig(
         vocab_size=len(config['tokenizer.ggml.tokens']),
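The added `qtype = ggml_tensor_qtype[low_bit]` line resolves the requested low-bit format string into a ggml quantization type id once, so the per-tensor callback further down can pass it straight to the conversion helper. Roughly, the lookup works like this (the table below uses made-up placeholder values, not the real `ggml_tensor_qtype` mapping from bigdl.llm.ggml.quantize):

# Placeholder values for illustration only.
ggml_tensor_qtype = {
    "sym_int4": 2,
    "asym_int4": 3,
    "sym_int8": 8,
}

def resolve_qtype(low_bit: str = "sym_int4") -> int:
    if low_bit not in ggml_tensor_qtype:
        raise ValueError(f"unsupported low_bit format: {low_bit}")
    return ggml_tensor_qtype[low_bit]

qtype = resolve_qtype("sym_int4")   # 2 with the placeholder table above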
@@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
                                     "cpu",
                                     tensor,
                                     dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
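Taken together, the mixtral changes quantize each module inside the per-tensor loading callback: as soon as a module's weights have been copied in, that single module is converted to the requested low-bit format, so the full-precision weights of the large MoE checkpoint never have to be resident all at once. Below is a minimal sketch of this convert-as-you-load pattern; `load_while_process`, `convert_module_to_low_bit`, and the fake tensor stream are illustrative stand-ins, not the BigDL or GGUF APIs.

import torch
import torch.nn as nn

def convert_module_to_low_bit(model: nn.Module, module_name: str) -> nn.Module:
    # Stand-in for replace_with_low_bit_linear_for_module: round-trip the named
    # module's weight through an int8-style grid to mimic the real replacement,
    # which would swap the layer for a quantized LowBitLinear module instead.
    module = model.get_submodule(module_name)
    if isinstance(module, nn.Linear):
        with torch.no_grad():
            scale = module.weight.abs().max() / 127.0
            q = (module.weight / scale).round().clamp(-128, 127)
            module.weight = nn.Parameter(q * scale, requires_grad=False)
    return model

def load_while_process(tensor_stream, process):
    # Stand-in for GGUF's tensor_loader.load_while_process: hand one
    # (name, tensor) pair at a time to the callback, so only one
    # full-precision tensor needs to be materialized at any moment.
    for name, tensor in tensor_stream:
        process(name, tensor)

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))

def process(name, tensor):
    module_name, _, _ = name.rpartition(".")
    model.get_submodule(module_name).weight.data.copy_(tensor)
    # Convert this module immediately, before the next tensor is loaded.
    convert_module_to_low_bit(model, module_name)

fake_tensor_stream = ((f"{i}.weight", torch.randn(64, 64)) for i in range(2))
load_while_process(fake_tensor_stream, process)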