gguf memory optimization for mixtral (#9939)
This commit is contained in:
parent
610b5226be
commit
967714bac8
3 changed files with 11 additions and 3 deletions
|
|
@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
|
|||
is_linear, linear_args = is_linear_module(module)
|
||||
if is_linear and name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
|
||||
if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
|
||||
module.weight.data.device.type != 'meta' and
|
||||
not isinstance(module, LowBitLinear)):
|
||||
in_features, out_features, mp_group = linear_args
|
||||
with init_empty_weights():
|
||||
new_linear = None
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
|
|||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
|
||||
|
||||
|
||||
def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
|
||||
low_bit='sym_int4'):
|
||||
config = loader.config
|
||||
|
|
@ -78,7 +79,7 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
|
|||
else:
|
||||
set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
|
||||
model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
|
||||
|
||||
|
||||
tensor_loader = loader.tensor_loader
|
||||
tensor_loader.load_while_process(process_mistral)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
|
|||
from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
|
||||
|
||||
from ..gguf import GGUFFileLoader
|
||||
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
||||
from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
|
||||
|
||||
|
||||
def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
||||
def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
|
||||
low_bit='sym_int4'):
|
||||
# mixtral enjoys a general architecture of llma
|
||||
# e.g. it applies llama tokenizer
|
||||
config = loader.config
|
||||
|
|
@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
|||
n_head = config['llama.attention.head_count']
|
||||
n_head_kv = config['llama.attention.head_count_kv']
|
||||
hidden_size = config['llama.embedding_length']
|
||||
qtype = ggml_tensor_qtype[low_bit]
|
||||
|
||||
mixtral_config = MixtralConfig(
|
||||
vocab_size=len(config['tokenizer.ggml.tokens']),
|
||||
|
|
@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
|
|||
"cpu",
|
||||
tensor,
|
||||
dtype=dtype)
|
||||
model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
|
||||
|
||||
tensor_loader = loader.tensor_loader
|
||||
tensor_loader.load_while_process(process_mixtral)
|
||||
|
|
|
|||
Loading…
Reference in a new issue