Optimize gguf load memory for mistral (#9923)

* optimize gguf load for mistral

* fix output of gguf mistral

* reset
Authored by Lilac09 on 2024-01-19 09:14:39 +08:00, committed by GitHub
parent 9a46f019d7
commit 7032a2ad73


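The change below replaces the previous loading flow, which read every GGUF tensor into a full state_dict and only then copied the weights into the model, with a streaming flow: tensor_loader.load_while_process hands tensors to a callback one at a time, and each touched module is converted to the requested low_bit format on the spot, so the full-precision copy of a layer can be dropped before the next tensor arrives. A minimal sketch of the streaming pattern, with a hypothetical iter_tensors() generator and a toy model standing in for the real GGUFFileLoader internals and Mistral layers:

import torch
import torch.nn as nn

def iter_tensors():
    # Hypothetical stand-in for the GGUF tensor loader: yields one
    # (name, tensor) pair at a time instead of a dict of all weights.
    for i in range(2):
        yield f"blk.{i}.ffn_up.weight", torch.randn(8, 4)

def name_map(gguf_name):
    # Toy version of get_mistral_module_name in the diff below:
    # maps a GGUF tensor name to a module key.
    return "up" + gguf_name.split(".")[1]

model = nn.ModuleDict({"up0": nn.Linear(4, 8, bias=False),
                       "up1": nn.Linear(4, 8, bias=False)})

def process(name, tensor):
    # Assign the weight; the real callback also quantizes the module
    # right here (replace_with_low_bit_linear_for_module), so the
    # full-precision copy is no longer needed once this returns.
    model[name_map(name)].weight.data.copy_(tensor)

# Streaming load: peak extra memory ~ one tensor, not the whole checkpoint.
for name, tensor in iter_tensors():
    process(name, tensor)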
@@ -22,9 +22,11 @@ from tempfile import NamedTemporaryFile
 from transformers import MistralConfig, MistralForCausalLM, LlamaTokenizer
 from ..gguf import GGUFFileLoader
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
-def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     config = loader.config
     mistral_config = MistralConfig(
@@ -44,42 +46,41 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         pretraining_tp=1,
     )
-    ckpt = loader.tensors(dtype)
+    qtype = ggml_tensor_qtype[low_bit]
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
-    ckpt = restore_mistral_weight(ckpt, n_head, n_head_kv)
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['llama.block_count']):
-        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_q.weight']
-        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_k.weight']
-        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_v.weight']
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate.weight']
-        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_up.weight']
-        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_down.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
     with init_empty_weights():
         model = MistralForCausalLM(mistral_config)
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
+    def process_mistral(name, tensor):
+        nonlocal model
+        module_name = get_mistral_module_name(name)
+        if name.endswith("attn_q.weight"):
+            # gguf weight needs to reshape for q_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
+                                        .swapaxes(1, 2)
+                                        .reshape(tensor.shape),
+                                        dtype=dtype)
+        elif name.endswith("attn_k.weight"):
+            # gguf weight needs to reshape for k_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head_kv,
+                                                       head // n_head_kv // 2,
+                                                       2,
+                                                       *hd_size)
+                                        .swapaxes(1, 2)
+                                        .reshape(tensor.shape),
+                                        dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
+        model = model.cpu()
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_mistral)
     # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
     from transformers.convert_slow_tokenizer import import_protobuf
@@ -100,18 +101,29 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     return model, tokenizer
-def restore_mistral_weight(ckpt: dict, n_head: int, n_head_kv: int):
-    # see https://github.com/ggerganov/llama.cpp/blob
-    # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
-    for name, weight in ckpt.items():
-        head, hd_size = weight.shape[0], weight.shape[1:]
-        if name.endswith("attn_q.weight"):
-            ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-        elif name.endswith("attn_k.weight"):
-            ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-    return ckpt
+def get_mistral_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name:
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+    if 'attn_k' in name:
+        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
+    if 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.mlp.gate_proj.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.mlp.up_proj.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.mlp.down_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
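A note on the Q/K reshape that survives the rewrite (formerly in restore_mistral_weight, now inline in process_mistral): it inverts the permute() applied by llama.cpp's convert.py (linked in the removed comment), restoring the row order that the Hugging Face rotary-embedding implementation expects. A tiny self-contained check of what that expression does, using a deliberately small hypothetical shape (1 head, 4 rows) rather than Mistral's real dimensions:

import torch

def permute_back(weight, n_head):
    # Same expression applied to attn_q / attn_k weights in the diff:
    # split each head's rows into (row-pair, 2) groups, swap the two axes,
    # and flatten back, i.e. a head's even rows first, then its odd rows.
    rows, cols = weight.shape
    return (weight.reshape(n_head, rows // n_head // 2, 2, cols)
                  .swapaxes(1, 2)
                  .reshape(rows, cols))

w = torch.arange(16.0).reshape(4, 4)   # rows 0..3 as stored in the GGUF file
print(permute_back(w, n_head=1))       # rows come out in the order 0, 2, 1, 3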