Load Mixtral gguf in a block-wise way (#9725)

* Load Mixtral gguf in a block-wise way

* refine
Heyang Sun 2023-12-21 10:03:23 +08:00 committed by GitHub
parent 34bb804189
commit d157f623b6
2 changed files with 64 additions and 35 deletions


@@ -246,6 +246,24 @@ class GGUFTensorLoader:
                tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
                yield name, tensor

    def load_while_process(self, process):
        with open(self.fpath, 'rb') as f:
            for name, ndims, dims, qtype, offset in tqdm(self.infos, desc="Loading gguf tensors"):
                total_ne = functools.reduce(lambda x, y: x * y, dims)
                invalidInputError(total_ne % self.block_ne[qtype] == 0,
                                  f"wrong elements num: {dims}")
                size = total_ne // self.block_ne[qtype] * self.block_size[qtype]
                invalidInputError(size != 0, f"unsupported quantize type: {qtype}")
                offset += self.base_offset
                f.seek(offset)
                data = f.read(size)
                arr = numpy.frombuffer(data, dtype=numpy.uint8)
                tensor = torch.from_numpy(arr)
                tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
                process(name, tensor)

    def convert_f32_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
        return tensor.view(torch.float)
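For illustration, a minimal sketch of how the new callback-style API might be driven (hypothetical usage; `tensor_loader` is assumed to be an already-constructed `GGUFTensorLoader`, e.g. obtained as `loader.tensor_loader` from a `GGUFFileLoader`, as in the Mixtral loader below). The callback receives each dequantized tensor right after it is read, so the caller never has to hold the whole checkpoint in memory at once:

import torch

def print_tensor_info(name: str, tensor: torch.Tensor):
    # Called once per tensor, immediately after it has been read and
    # dequantized, so only one tensor needs to be resident at a time.
    print(name, tuple(tensor.shape), tensor.dtype)

tensor_loader.load_while_process(print_tensor_info)  # tensor_loader: see above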


@@ -50,44 +50,21 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
        num_experts_per_tok=num_experts_per_tok,
    )

    ckpt = loader.tensors(dtype)
    from .llama import restore_llama_weight
    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)

    state_dict = {}
    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
    state_dict['lm_head.weight'] = ckpt['output.weight']
    for i in range(config['llama.block_count']):
        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
            ckpt[f'blk.{i}.attn_q.weight']
        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
            ckpt[f'blk.{i}.attn_k.weight']
        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
            ckpt[f'blk.{i}.attn_v.weight']
        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
            ckpt[f'blk.{i}.attn_output.weight']
        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
            ckpt[f'blk.{i}.attn_norm.weight']
        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
            ckpt[f'blk.{i}.ffn_norm.weight']
        state_dict[f'model.layers.{i}.block_sparse_moe.gate.weight'] = \
            ckpt[f'blk.{i}.ffn_gate_inp.weight'].reshape(num_local_experts, hidden_size)
        for j in range(num_local_experts):
            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight'] = \
                ckpt[f'blk.{i}.ffn_gate.{j}.weight']
            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight'] = \
                ckpt[f'blk.{i}.ffn_down.{j}.weight']
            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight'] = \
                ckpt[f'blk.{i}.ffn_up.{j}.weight']

    with init_empty_weights():
        model = MixtralForCausalLM(mixtral_config)

    for name, weight in state_dict.items():
        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
    def process_mixtral(name, tensor):
        module_name = get_mixtral_module_name(name)
        if 'ffn_gate_inp' in name:
            # the gguf weight needs to be reshaped for ffn_gate_inp
            set_module_tensor_to_device(model, module_name, "cpu",
                                        tensor.reshape(num_local_experts, hidden_size),
                                        dtype=dtype)
        else:
            set_module_tensor_to_device(model, module_name, "cpu",
                                        tensor, dtype=dtype)
    model = model.cpu()

    tensor_loader = loader.tensor_loader
    tensor_loader.load_while_process(process_mixtral)

    from transformers.convert_slow_tokenizer import import_protobuf
    spm_pb2 = import_protobuf("Failed to import protobuf")
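As background, a small self-contained sketch of the pattern `process_mixtral` builds on: the model is created with empty (meta) weights under `init_empty_weights`, and `set_module_tensor_to_device` from `accelerate` then materializes one parameter at a time as tensors stream in from the gguf file. The toy module and shapes below are made up for illustration and are not part of the commit:

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Toy stand-in for MixtralForCausalLM: parameters live on the "meta" device,
# so no host memory is allocated until each weight is materialized.
with init_empty_weights():
    toy = nn.Linear(4, 2, bias=False)

# Materialize a single parameter on CPU, as process_mixtral does per gguf tensor.
set_module_tensor_to_device(toy, "weight", "cpu",
                            torch.randn(2, 4), dtype=torch.float32)
print(toy.weight.device, tuple(toy.weight.shape))  # cpu (2, 4)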
@@ -105,3 +82,37 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
    os.remove(f.name)

    return model, tokenizer


def get_mixtral_module_name(name):
    if name == 'token_embd.weight':
        return 'model.embed_tokens.weight'
    if name == 'output_norm.weight':
        return 'model.norm.weight'
    if name == 'output.weight':
        return 'lm_head.weight'
    layer_id = name.split('.')[1]
    if 'attn_q' in name:
        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
    if 'attn_k' in name:
        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
    if 'attn_v' in name:
        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
    if 'attn_output' in name:
        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
    if 'attn_norm' in name:
        return f'model.layers.{layer_id}.input_layernorm.weight'
    if 'ffn_norm' in name:
        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
    if 'ffn_gate_inp' in name:
        return f'model.layers.{layer_id}.block_sparse_moe.gate.weight'
    local_expert_id = name.split('.')[3]
    if 'ffn_gate' in name:
        return f'model.layers.{layer_id}.' + \
            f'block_sparse_moe.experts.{local_expert_id}.w1.weight'
    if 'ffn_down' in name:
        return f'model.layers.{layer_id}.' + \
            f'block_sparse_moe.experts.{local_expert_id}.w2.weight'
    if 'ffn_up' in name:
        return f'model.layers.{layer_id}.' + \
            f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
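To make the naming convention concrete, here are a few sample mappings derived from the rules above (the gguf keys follow the blk.{layer}.{tensor}[.{expert}].weight pattern that the old state_dict code used). Note that the 'ffn_gate_inp' branch must stay ahead of the 'ffn_gate' branch, since the former name contains the latter as a substring:

# Illustrative sanity check for get_mixtral_module_name (not part of the commit).
examples = [
    'token_embd.weight',          # -> model.embed_tokens.weight
    'blk.0.attn_q.weight',        # -> model.layers.0.self_attn.q_proj.weight
    'blk.3.ffn_gate_inp.weight',  # -> model.layers.3.block_sparse_moe.gate.weight
    'blk.3.ffn_up.7.weight',      # -> model.layers.3.block_sparse_moe.experts.7.w3.weight
]
for gguf_name in examples:
    print(gguf_name, '->', get_mixtral_module_name(gguf_name))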