diff --git a/python/llm/src/bigdl/llm/transformers/gguf/gguf.py b/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
index 0dcfcdd0..09487828 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/gguf.py
@@ -246,6 +246,24 @@ class GGUFTensorLoader:
                 tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
                 yield name, tensor
 
+    def load_while_process(self, process):
+        with open(self.fpath, 'rb') as f:
+            for name, ndims, dims, qtype, offset in tqdm(self.infos, desc="Loading gguf tensors"):
+                total_ne = functools.reduce(lambda x, y: x * y, dims)
+                invalidInputError(total_ne % self.block_ne[qtype] == 0,
+                                  f"wrong elements num: {dims}")
+
+                size = total_ne // self.block_ne[qtype] * self.block_size[qtype]
+                invalidInputError(size != 0, f"unsupported quantize type: {qtype}")
+
+                offset += self.base_offset
+                f.seek(offset)
+                data = f.read(size)
+                arr = numpy.frombuffer(data, dtype=numpy.uint8)
+                tensor = torch.from_numpy(arr)
+                tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
+                process(name, tensor)
+
     def convert_f32_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
         return tensor.view(torch.float)
 
diff --git a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
index c7456d55..3d77ac62 100644
--- a/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/gguf/models/mixtral.py
@@ -50,44 +50,21 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         num_experts_per_tok=num_experts_per_tok,
     )
 
-    ckpt = loader.tensors(dtype)
-    from .llama import restore_llama_weight
-    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
-
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['llama.block_count']):
-        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_q.weight']
-        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_k.weight']
-        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_v.weight']
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
-        state_dict[f'model.layers.{i}.block_sparse_moe.gate.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate_inp.weight'].reshape(num_local_experts, hidden_size)
-        for j in range(num_local_experts):
-            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight'] = \
-                (ckpt[f'blk.{i}.ffn_gate.{j}.weight'])
-            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight'] = \
-                ckpt[f'blk.{i}.ffn_down.{j}.weight']
-            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight'] = \
-                ckpt[f'blk.{i}.ffn_up.{j}.weight']
-
     with init_empty_weights():
         model = MixtralForCausalLM(mixtral_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
-
-    model = model.cpu()
+    def process_mixtral(name, tensor):
+        module_name = get_mixtral_module_name(name)
+        if 'ffn_gate_inp' in name:
+            # gguf weight needs to reshape for ffn_gate_inp
+            set_module_tensor_to_device(model, module_name, "cpu", \
+                tensor.reshape(num_local_experts, hidden_size), dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, module_name, "cpu", \
+                tensor, dtype=dtype)
+
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_mixtral)
 
     from transformers.convert_slow_tokenizer import import_protobuf
     spm_pb2 = import_protobuf("Failed to import protobuf")
@@ -105,3 +82,37 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     os.remove(f.name)
 
     return model, tokenizer
+
+def get_mixtral_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name:
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+    if 'attn_k' in name:
+        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
+    if 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
+    if 'ffn_gate_inp' in name:
+        return f'model.layers.{layer_id}.block_sparse_moe.gate.weight'
+    local_expert_id = name.split('.')[3]
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.' + \
+            f'block_sparse_moe.experts.{local_expert_id}.w1.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.' + \
+            f'block_sparse_moe.experts.{local_expert_id}.w2.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.' + \
+            f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
+