gguf memory optimization for baichuan (#9937)
This commit is contained in:
parent 2e1448f08e
commit 9a46f019d7
1 changed file with 66 additions and 47 deletions
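In short: instead of materializing the whole checkpoint with `loader.tensors(dtype)` and assembling a full `state_dict` before touching the model, the loader now streams GGUF tensors one at a time through a `process_baichuan` callback, places each weight into the meta-initialized model, and immediately converts the affected module to the requested low-bit format (a new `low_bit` argument, `'sym_int4'` by default). Peak memory therefore stays near the size of a few full-precision tensors instead of the entire model.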
@@ -24,9 +24,12 @@ from .model_implement.baichuan.modeling_baichuan import BaiChuanForCausalLM
 from .model_implement.baichuan.tokenization_baichuan import BaiChuanTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                       low_bit='sym_int4'):
     config = loader.config
 
     baichuan_config = BaiChuanConfig(
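With the new signature, callers choose the quantization type at load time; `low_bit` is looked up in `ggml_tensor_qtype`, so any key of that dict should be accepted. A hypothetical call (the `GGUFFileLoader` constructor arguments and file name are illustrative, not taken from this diff):

import torch

# Illustrative only: the exact construction of GGUFFileLoader is not shown here.
loader = GGUFFileLoader('baichuan-7b.gguf')

# Default: eligible Linear modules are quantized to symmetric int4.
model, tokenizer = load_gguf_baichuan(loader, dtype=torch.float16)

# Assuming 'sym_int8' is also a key of ggml_tensor_qtype:
model, tokenizer = load_gguf_baichuan(loader, dtype=torch.float16, low_bit='sym_int8')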
@@ -46,43 +49,36 @@ def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float)
         pretraining_tp=1,
     )
 
-    ckpt = loader.tensors(dtype)
+    qtype = ggml_tensor_qtype[low_bit]
     n_head = config['baichuan.attention.head_count']
     n_head_kv = config['baichuan.attention.head_count_kv']
-    ckpt = restore_baichuan_weight(ckpt, n_head, n_head_kv)
-
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['baichuan.block_count']):
-        # rebuild W_pack
-        a = ckpt[f'blk.{i}.attn_q.weight']
-        b = ckpt[f'blk.{i}.attn_k.weight']
-        c = ckpt[f'blk.{i}.attn_v.weight']
-        d = torch.cat([a, b, c], dim=0)
-        state_dict[f'model.layers.{i}.self_attn.W_pack.weight'] = d
-
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate.weight']
-        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_up.weight']
-        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_down.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
 
     with init_empty_weights():
         model = BaiChuanForCausalLM(baichuan_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
+    attn_q_tensor, attn_k_tensor, attn_v_tensor = [torch.tensor([]) for _ in range(3)]
 
     model = model.cpu()
+
+    def process_baichuan(name, tensor):
+        nonlocal model, attn_q_tensor, attn_k_tensor, attn_v_tensor
+        module_name = get_baichuan_module_name(name)
+        tensor = restore_baichuan_weight(name, tensor, n_head, n_head_kv)
+
+        if 'attn_q' in name:
+            attn_q_tensor = tensor
+            return
+        if 'attn_k' in name:
+            attn_k_tensor = tensor
+            return
+        if 'attn_v' in name:
+            attn_v_tensor = tensor
+            tensor = torch.cat([attn_q_tensor, attn_k_tensor, attn_v_tensor], dim=0)
+        set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        if 'lm_head' in module_name:
+            return
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
+
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_baichuan)
 
     # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
     from transformers.convert_slow_tokenizer import import_protobuf
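The memory win comes from the streaming pattern: `loader.tensor_loader.load_while_process(callback)` hands each tensor to the callback as it is read instead of returning them all at once, so each full-precision weight can be freed as soon as its module has been quantized. A minimal sketch of the idea (not `GGUFFileLoader`'s actual implementation; `read_one_tensor` and `n_tensors` are stand-ins):

import torch
from typing import Callable, Tuple

def load_while_process_sketch(read_one_tensor: Callable[[int], Tuple[str, torch.Tensor]],
                              n_tensors: int,
                              process: Callable[[str, torch.Tensor], None]) -> None:
    for i in range(n_tensors):
        name, tensor = read_one_tensor(i)  # stand-in: read a single GGUF record
        process(name, tensor)              # e.g. process_baichuan above
        del tensor                         # drop our reference; unless the callback
                                           # buffered it, the memory can be reclaimed

Two details worth noting: `process_baichuan` must buffer `attn_q` and `attn_k` until the layer's `attn_v` arrives, because Baichuan's fused `W_pack` is the row-wise concatenation of the three projections (so this relies on the GGUF file storing each layer's q, k and v in that order); and `lm_head` is deliberately skipped, so the output projection stays unquantized.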
@@ -103,21 +99,44 @@ def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float)
     return model, tokenizer
 
 
-def restore_baichuan_weight(ckpt: dict, n_head: int, n_head_kv: int):
+def restore_baichuan_weight(name, weight, n_head: int, n_head_kv: int):
     # see https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py#L535
 
-    for name, weight in ckpt.items():
-        head, hd_size = weight.shape[0], weight.shape[1:]
-        if n_head != n_head_kv:
-            new_n_head = n_head // n_head_kv
-        else:
-            new_n_head = n_head
-        if name.endswith("attn_q.weight"):
-            ckpt[name] = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-        elif name.endswith("attn_k.weight"):
-            ckpt[name] = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-    return ckpt
+    head, hd_size = weight.shape[0], weight.shape[1:]
+    if n_head != n_head_kv:
+        new_n_head = n_head // n_head_kv
+    else:
+        new_n_head = n_head
+    if name.endswith("attn_q.weight"):
+        weight = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
+                  .swapaxes(1, 2)
+                  .reshape(weight.shape))
+    elif name.endswith("attn_k.weight"):
+        weight = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
+                  .swapaxes(1, 2)
+                  .reshape(weight.shape))
+    return weight
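`restore_baichuan_weight` now operates on a single named tensor rather than the whole checkpoint dict, which is what lets it run inside the streaming callback. The reshape/swapaxes undoes the head reordering that llama.cpp's HF-to-GGUF converter applies to `attn_q`/`attn_k` (see the linked convert-hf-to-gguf.py). A small self-contained round-trip check, with toy sizes and the forward permutation written from that converter's published recipe (illustrative, not imported from llama.cpp):

import torch

n_head, head_dim, hidden = 4, 8, 32          # toy sizes, illustration only
w = torch.randn(n_head * head_dim, hidden)

def permute(weight, n_head):
    # forward reordering, as applied when exporting HF weights to GGUF
    head = weight.shape[0]
    return (weight.reshape(n_head, 2, head // n_head // 2, *weight.shape[1:])
                  .swapaxes(1, 2)
                  .reshape(weight.shape))

def restore(weight, n_head):
    # inverse reordering, mirroring restore_baichuan_weight above
    head = weight.shape[0]
    return (weight.reshape(n_head, head // n_head // 2, 2, *weight.shape[1:])
                  .swapaxes(1, 2)
                  .reshape(weight.shape))

assert torch.equal(restore(permute(w, n_head), n_head), w)

The hunk continues with the new `get_baichuan_module_name` helper, which the callback uses to map GGUF tensor names onto the Hugging Face module tree: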
+
+
+def get_baichuan_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name or 'attn_k' in name or 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.W_pack.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.mlp.gate_proj.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.mlp.up_proj.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.mlp.down_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
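All three attention projections deliberately resolve to the same fused parameter, which is why the callback only fires the `W_pack` write once `attn_v` arrives. For example:

assert get_baichuan_module_name('blk.0.attn_q.weight') == \
    'model.layers.0.self_attn.W_pack.weight'
assert get_baichuan_module_name('blk.0.attn_v.weight') == \
    'model.layers.0.self_attn.W_pack.weight'
assert get_baichuan_module_name('blk.7.ffn_down.weight') == \
    'model.layers.7.mlp.down_proj.weight'
assert get_baichuan_module_name('output.weight') == 'lm_head.weight'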