gguf memory optimization for mixtral (#9939)

parent 610b5226be
commit 967714bac8

3 changed files with 11 additions and 3 deletions
@@ -202,7 +202,9 @@ def _replace_with_low_bit_linear(model, qtype, modules_to_not_convert=None,
         is_linear, linear_args = is_linear_module(module)
         if is_linear and name not in modules_to_not_convert:
             # Check if the current key is not in the `modules_to_not_convert`
-            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
+            if (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
+                    module.weight.data.device.type != 'meta' and
+                    not isinstance(module, LowBitLinear)):
                 in_features, out_features, mp_group = linear_args
                 with init_empty_weights():
                     new_linear = None
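For readers following the change in `_replace_with_low_bit_linear`: the widened condition converts a linear module only when it is not excluded by name, its weights are real rather than `meta` placeholders (as created under `init_empty_weights()`), and it is not already a `LowBitLinear`. The two extra checks presumably keep repeated per-module conversion passes (used by the GGUF loaders below) from touching not-yet-loaded or already-converted layers. A minimal, self-contained sketch of that predicate, with `LowBitLinear` stubbed out since the real class ships with bigdl-llm:

```python
import torch.nn as nn


class LowBitLinear(nn.Linear):
    """Stand-in for bigdl-llm's LowBitLinear, used only for the isinstance check."""
    pass


def should_convert(module, current_key_name, modules_to_not_convert):
    # Sketch of the new guard in _replace_with_low_bit_linear (not the real code).
    return (not any(key in ".".join(current_key_name) for key in modules_to_not_convert) and
            module.weight.data.device.type != 'meta' and
            not isinstance(module, LowBitLinear))


layer = nn.Linear(16, 16)
print(should_convert(layer, ["model", "layers", "0", "q_proj"], ["lm_head"]))  # True: convert
print(should_convert(layer, ["model", "lm_head"], ["lm_head"]))                # False: excluded by name
```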
@@ -25,6 +25,7 @@ from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module


 def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
                       low_bit='sym_int4'):
     config = loader.config
@@ -22,9 +22,12 @@ from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer

 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module


-def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     # mixtral enjoys a general architecture of llama
     # e.g. it applies llama tokenizer
     config = loader.config
@@ -33,6 +36,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
     hidden_size = config['llama.embedding_length']
+    qtype = ggml_tensor_qtype[low_bit]

     mixtral_config = MixtralConfig(
         vocab_size=len(config['tokenizer.ggml.tokens']),
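The new `low_bit` argument is resolved into a ggml quantization type id once, and that id is then reused for every module conversion during loading. A short illustration of the lookup (assuming bigdl-llm is installed; the concrete integer values are whatever the library defines):

```python
from bigdl.llm.ggml.quantize import ggml_tensor_qtype

low_bit = 'sym_int4'                 # default used by load_gguf_mixtral above
qtype = ggml_tensor_qtype[low_bit]   # integer id later passed as qtype=... per module
print(low_bit, '->', qtype)
```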
@@ -81,6 +85,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
            "cpu",
            tensor,
            dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)

     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
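The memory saving comes from this last hunk: each weight is converted to the low-bit format right after its GGUF tensor has been copied into the model, instead of quantizing the whole full-precision model after loading, so at most one full-precision weight (beyond the already-quantized ones) is resident at a time. A rough, self-contained sketch of that pattern, using a toy two-layer model and a toy `to_low_bit_` in place of the real MixtralForCausalLM and `replace_with_low_bit_linear_for_module`:

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Toy stand-in for the Mixtral model being filled in from GGUF tensors."""
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(64, 64, bias=False)
        self.k_proj = nn.Linear(64, 64, bias=False)


def to_low_bit_(model, module_name):
    # Toy stand-in for replace_with_low_bit_linear_for_module: shrink one
    # module's weight to int8 storage and return the (mutated) model.
    module = dict(model.named_modules())[module_name]
    module.weight.data = (module.weight.data.clamp(-1, 1) * 127).round().to(torch.int8)
    return model


def load_while_process(tensors, model):
    # Streaming pattern used by load_gguf_mixtral: copy one tensor into the
    # model, then immediately convert that module before loading the next one.
    for name, tensor in tensors.items():
        module_name = name.rsplit(".", 1)[0]          # "q_proj.weight" -> "q_proj"
        dict(model.named_modules())[module_name].weight.data = tensor
        model = to_low_bit_(model, module_name)
    return model


checkpoint = {"q_proj.weight": torch.randn(64, 64), "k_proj.weight": torch.randn(64, 64)}
model = load_while_process(checkpoint, TinyModel())
print(model.q_proj.weight.dtype)  # torch.int8: converted as soon as it was loaded
```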