From 967714bac80e2ab15c6f166fcf51a02cb685608a Mon Sep 17 00:00:00 2001
From: Shaojun Liu <61072813+liu-shaojun@users.noreply.github.com>
Date: Fri, 19 Jan 2024 11:13:15 +0800
Subject: [PATCH] gguf memory optimization for mixtral (#9939)

---
 python/llm/src/bigdl/llm/transformers/convert.py           | 4 +++-
 .../llm/src/bigdl/llm/transformers/gguf/models/mistral.py  | 3 ++-
 .../llm/src/bigdl/llm/transformers/gguf/models/mixtral.py  | 7 ++++++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 441e7e4d..e3072179 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         is_linear, linear_args = is_linear_module(module)
         if is_linear and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+                    module.weight.data.device.type != 'meta' and
+                    not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
                 with init_empty_weights():
                     new_linear = None
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
index 2c4ace53..ba1feae4 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
@@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
+
 def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
                       low_bit='sym_int4'):
     config = loader.config
@@ -78,7 +79,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
         else:
             set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
             model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
-    
+
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mistral)
 
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
index 4ab221b8..216ceccb 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     # mixtral enjoys a general architecture of llma
     # e.g. it applies llama tokenizer
     config = loader.config
@@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
     hidden_size = config['llama.embedding_length']
+    qtype = ggml_tensor_qtype[low_bit]
 
     mixtral_config = MixtralConfig(
         vocab_size=len(config['tokenizer.ggml.tokens']),
@@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
                                         "cpu",
                                         tensor,
                                         dtype=dtype)
+            model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
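Note on the pattern this patch applies: instead of materializing the whole Mixtral checkpoint in float precision and quantizing it at the end, each GGUF tensor is placed on CPU and the module that owns it is converted to the requested low-bit format right away, so only one full-precision tensor needs to be alive at a time. Below is a minimal sketch of that streaming conversion loop; map_tensor_to_module_name is a hypothetical placeholder for the loader's real GGUF-name-to-module mapping, and the other calls mirror only the usage visible in the diff above, not the library's full API.

# Minimal sketch of the streamed low-bit conversion used in this patch.
# Assumption: map_tensor_to_module_name is a hypothetical placeholder; the real
# loaders keep their own GGUF-tensor-name -> Hugging Face module-name mapping.
import torch
from accelerate.utils import set_module_tensor_to_device
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module


def map_tensor_to_module_name(gguf_name: str) -> str:
    # Hypothetical stand-in: e.g. "blk.0.attn_q.weight" would map to
    # "model.layers.0.self_attn.q_proj.weight" in the Hugging Face model.
    return gguf_name


def load_low_bit_streaming(loader, model, low_bit='sym_int4', dtype=torch.float):
    qtype = ggml_tensor_qtype[low_bit]

    def process(name, tensor):
        nonlocal model
        module_name = map_tensor_to_module_name(name)
        # Materialize only this tensor on CPU in full precision ...
        set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
        # ... then immediately quantize the owning module, so the float weight
        # is released before the next tensor is read from the GGUF file.
        model = replace_with_low_bit_linear_for_module(
            model, qtype=qtype, module_name=module_name)

    loader.tensor_loader.load_while_process(process)
    return model

Keeping the per-module conversion inside the tensor callback is what bounds peak memory to roughly one unconverted tensor plus the already-quantized modules; the extra device and isinstance checks added in convert.py keep the later whole-model pass from re-converting modules that were already handled here or that still sit on the meta device.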