diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
index 8add59bb..2c4ace53 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mistral.py
@@ -22,9 +22,11 @@ from tempfile import NamedTemporaryFile
 from transformers import MistralConfig, MistralForCausalLM, LlamaTokenizer
 
 from ..gguf import GGUFFileLoader
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from bigdl.llm.transformers.convert import replace_with_low_bit_linear_for_module
 
-
-def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
+def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float,
+                      low_bit='sym_int4'):
     config = loader.config
 
     mistral_config = MistralConfig(
@@ -44,42 +46,41 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         pretraining_tp=1,
     )
 
-    ckpt = loader.tensors(dtype)
+    qtype = ggml_tensor_qtype[low_bit]
     n_head = config['llama.attention.head_count']
     n_head_kv = config['llama.attention.head_count_kv']
-    ckpt = restore_mistral_weight(ckpt, n_head, n_head_kv)
-
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['llama.block_count']):
-        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_q.weight']
-        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_k.weight']
-        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_v.weight']
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate.weight']
-        state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_up.weight']
-        state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = \
-            ckpt[f'blk.{i}.ffn_down.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
 
     with init_empty_weights():
         model = MistralForCausalLM(mistral_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
-
-    model = model.cpu()
+    def process_mistral(name, tensor):
+        nonlocal model
+        module_name = get_mistral_module_name(name)
+        if name.endswith("attn_q.weight"):
+            # gguf weight needs to reshape for q_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head, head // n_head // 2, 2, *hd_size)
+                                        .swapaxes(1, 2)
+                                        .reshape(tensor.shape),
+                                        dtype=dtype)
+        elif name.endswith("attn_k.weight"):
+            # gguf weight needs to reshape for k_proj
+            head, hd_size = tensor.shape[0], tensor.shape[1:]
+            set_module_tensor_to_device(model, module_name, "cpu",
+                                        tensor.reshape(n_head_kv,
+                                                       head // n_head_kv // 2,
+                                                       2,
+                                                       *hd_size)
+                                        .swapaxes(1, 2)
+                                        .reshape(tensor.shape),
+                                        dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, module_name, "cpu", tensor, dtype=dtype)
+        model = replace_with_low_bit_linear_for_module(model, qtype=qtype, module_name=module_name)
+
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_mistral)
 
     # see https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
     from transformers.convert_slow_tokenizer import import_protobuf
@@ -100,18 +101,29 @@ def load_gguf_mistral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     return model, tokenizer
 
 
-def restore_mistral_weight(ckpt: dict, n_head: int, n_head_kv: int):
-    # see https://github.com/ggerganov/llama.cpp/blob
-    # /3e73d31d9cc0232882ce61c64742aff3ecfec416/convert.py#L978
-
-    for name, weight in ckpt.items():
-        head, hd_size = weight.shape[0], weight.shape[1:]
-        if name.endswith("attn_q.weight"):
-            ckpt[name] = (weight.reshape(n_head, head // n_head // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-        elif name.endswith("attn_k.weight"):
-            ckpt[name] = (weight.reshape(n_head_kv, head // n_head_kv // 2, 2, *hd_size)
-                          .swapaxes(1, 2)
-                          .reshape(weight.shape))
-    return ckpt
+def get_mistral_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name:
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+    if 'attn_k' in name:
+        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
+    if 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.mlp.gate_proj.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.mlp.up_proj.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.mlp.down_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
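
For reviewers trying the change locally, a minimal usage sketch of the updated entry point is below. It is not part of the diff: the checkpoint filename is hypothetical, and constructing `GGUFFileLoader` directly from a `.gguf` path is an assumption based on the `from ..gguf import GGUFFileLoader` usage shown above; only the new `low_bit` argument to `load_gguf_mistral` comes from this PR.

```python
# Hedged usage sketch: load a Mistral GGUF checkpoint with low-bit linear layers.
# The .gguf path is hypothetical; GGUFFileLoader(fpath) is assumed to accept a file path.
import torch

from bigdl.llm.transformers.gguf.gguf import GGUFFileLoader
from bigdl.llm.transformers.gguf.models.mistral import load_gguf_mistral

loader = GGUFFileLoader('mistral-7b-instruct-v0.1.Q4_0.gguf')  # hypothetical file
# 'sym_int4' is the default; other ggml_tensor_qtype keys are expected to work as well
model, tokenizer = load_gguf_mistral(loader, dtype=torch.float, low_bit='sym_int4')

prompt = "What is AI?"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Compared with the removed `restore_mistral_weight` path, the callback passed to `tensor_loader.load_while_process` sets each weight and converts the corresponding module to low-bit as soon as that tensor is read, so the full-precision state dict never has to be materialized in memory at once.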