diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py
index fcaef116..49e81d48 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/baichuan.py
@@ -24,9 +24,12 @@
 from .model_implement.baichuan.modeling_baichuan import BaiChuanForCausalLM
 from .model_implement.baichuan.tokenization_baichuan import BaiChuanTokenizer
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
 
-def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                       low_bit='sym_int4'):
     config = loader.config
 
     baichuan_config = BaiChuanConfig(
@@ -46,43 +49,36 @@ def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float)
         pretraining_tp=1,
     )
 
-    ckpt = loader.tensors(dtype)
+    qtype = ggml_tensor_qtype[low_bit]
     n_head = config['baichuan.attention.head_count']
     n_head_kv = config['baichuan.attention.head_count_kv']
-    ckpt = restore_baichuan_weight(ckpt, n_head, n_head_kv)
-
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['baichuan.block_count']):
-        # rebuild W_pack
-        a = ckpt[f'blk.{i}.attn_q.weight']
-        b = ckpt[f'blk.{i}.attn_k.weight']
-        c = ckpt[f'blk.{i}.attn_v.weight']
-        d = torch.cat([a, b, c], dim=0)
-        state_dict[f'model.layers.{i}.self_attn.W_pack.weight'] = d
-
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate.weight']
-        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_up.weight']
-        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_down.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
 
     with init_empty_weights():
         model = BaiChuanForCausalLM(baichuan_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
+    attn_q_tensor, attn_k_tensor, attn_v_tensor = [torch.tensor([]) for _ in range(3)]
 
-    model = model.cpu()
+    def process_baichuan(name, tensor):
+        nonlocal model, attn_q_tensor, attn_k_tensor, attn_v_tensor
+        module_name = get_baichuan_module_name(name)
+        tensor = restore_baichuan_weight(name, tensor, n_head, n_head_kv)
+
+        if 'attn_q' in name:
+            attn_q_tensor = tensor
+            return
+        if 'attn_k' in name:
+            attn_k_tensor = tensor
+            return
+        if 'attn_v' in name:
+            attn_v_tensor = tensor
+            tensor = torch.cat([attn_q_tensor, attn_k_tensor, attn_v_tensor], dim=0)
+        set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        if 'lm_head' in module_name:
+            return
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
+
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_baichuan)
 
     # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
     from transformers.convert_slow_tokenizer import import_protobuf
@@ -103,21 +99,44 @@ def load_gguf_baichuan(loader: GGUFFileLoader, dtype: torch.dtype = torch.float)
     return model, tokenizer
 
 
-def restore_baichuan_weight(ckpt: dict, n_head: int, n_head_kv: int):
+def restore_baichuan_weight(name, weight, n_head: int, n_head_kv: int):
     # see https://github.com/ggerganov/llama.cpp/blob/master/convert-hf-to-gguf.py#L535
-    for name, weight in ckpt.items():
-        head, hd_size = weight.shape[0], weight.shape[1:]
-        if n_head != n_head_kv:
-            new_n_head = n_head // n_head_kv
-        else:
-            new_n_head = n_head
-        if name.endswith("attn_q.weight"):
-            ckpt[name] = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-        elif name.endswith("attn_k.weight"):
-            ckpt[name] = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-    return ckpt
+    head, hd_size = weight.shape[0], weight.shape[1:]
+    if n_head != n_head_kv:
+        new_n_head = n_head // n_head_kv
+    else:
+        new_n_head = n_head
+    if name.endswith("attn_q.weight"):
+        weight = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
+                  .swapaxes(1, 2)
+                  .reshape(weight.shape))
+    elif name.endswith("attn_k.weight"):
+        weight = (weight.reshape(new_n_head, head // new_n_head // 2, 2, *hd_size)
+                  .swapaxes(1, 2)
+                  .reshape(weight.shape))
+    return weight
+
+
+def get_baichuan_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name or 'attn_k' in name or 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.W_pack.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.mlp.gate_proj.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.mlp.up_proj.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.mlp.down_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
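Illustrative usage of the updated entry point (not part of this patch). It assumes GGUFFileLoader is constructed from a GGUF file path and that the checkpoint file name shown is a placeholder; with the new low_bit argument, weights are streamed tensor by tensor and each eligible linear module is converted to the requested low-bit format as it is loaded, instead of materializing the full state dict first:

    import torch

    from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader
    from bigdl.llm.transformers.gguf.models.baichuan import load_gguf_baichuan

    # Hypothetical checkpoint path; assumes GGUFFileLoader takes a GGUF file path.
    loader = GGUFFileLoader('baichuan2-7b-chat.Q4_0.gguf')

    # Load the model with per-tensor streaming and sym_int4 low-bit conversion.
    model, tokenizer = load_gguf_baichuan(loader, low_bit='sym_int4')

    inputs = tokenizer("What is AI?", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))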