Load Mixtral gguf in a block-wise way (#9725)

* Load Mixtral gguf in a block-wise way
* refine

This commit is contained in:
parent 34bb804189
commit d157f623b6

2 changed files with 64 additions and 35 deletions
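In short, instead of reading every GGUF tensor into a Python dict and then copying that dict into the model, the loader now walks the file once and hands each decoded tensor to a callback. A minimal sketch of the new call shape (the callback here is illustrative; the real one is process_mixtral in the second file below):

    # Illustrative only: `on_tensor` is a placeholder name, not part of this commit.
    def on_tensor(name, tensor):
        # copy `tensor` straight into the target module; only one gguf tensor
        # needs to be held outside the model at a time
        ...

    loader.tensor_loader.load_while_process(on_tensor)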
@@ -246,6 +246,24 @@ class GGUFTensorLoader:
                 tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
                 yield name, tensor
 
+    def load_while_process(self, process):
+        with open(self.fpath, 'rb') as f:
+            for name, ndims, dims, qtype, offset in tqdm(self.infos, desc="Loading gguf tensors"):
+                total_ne = functools.reduce(lambda x, y: x * y, dims)
+                invalidInputError(total_ne % self.block_ne[qtype] == 0,
+                                  f"wrong elements num: {dims}")
+
+                size = total_ne // self.block_ne[qtype] * self.block_size[qtype]
+                invalidInputError(size != 0, f"unsupported quantize type: {qtype}")
+
+                offset += self.base_offset
+                f.seek(offset)
+                data = f.read(size)
+                arr = numpy.frombuffer(data, dtype=numpy.uint8)
+                tensor = torch.from_numpy(arr)
+                tensor = self.convert_funcs[qtype](tensor, size, ndims, dims)
+                process(name, tensor)
+
     def convert_f32_tensor(self, tensor: torch.Tensor, size: int, ndims: int, dims: int):
         return tensor.view(torch.float)
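As a worked example of the size arithmetic in load_while_process above, assume a Q4_0 tensor (the standard GGUF Q4_0 layout packs 32 elements into an 18-byte block) with dims of 4096 x 4096; the shape is made up purely for illustration:

    # Q4_0: block_ne = 32 elements per block, block_size = 18 bytes per block
    total_ne = 4096 * 4096              # 16,777,216 elements
    assert total_ne % 32 == 0           # the "wrong elements num" check above
    size = total_ne // 32 * 18          # 9,437,184 bytes (9 MiB) read for this tensor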
@@ -50,44 +50,21 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         num_experts_per_tok=num_experts_per_tok,
     )
 
-    ckpt = loader.tensors(dtype)
-    from .llama import restore_llama_weight
-    ckpt = restore_llama_weight(ckpt, n_head, n_head_kv)
-
-    state_dict = {}
-    state_dict['model.embed_tokens.weight'] = ckpt['token_embd.weight']
-    state_dict['model.norm.weight'] = ckpt['output_norm.weight']
-    state_dict['lm_head.weight'] = ckpt['output.weight']
-    for i in range(config['llama.block_count']):
-        state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_q.weight']
-        state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_k.weight']
-        state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_v.weight']
-        state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = \
-            ckpt[f'blk.{i}.attn_output.weight']
-        state_dict[f'model.layers.{i}.input_layernorm.weight'] = \
-            ckpt[f'blk.{i}.attn_norm.weight']
-        state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = \
-            ckpt[f'blk.{i}.ffn_norm.weight']
-        state_dict[f'model.layers.{i}.block_sparse_moe.gate.weight'] = \
-            ckpt[f'blk.{i}.ffn_gate_inp.weight'].reshape(num_local_experts, hidden_size)
-        for j in range(num_local_experts):
-            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight'] = \
-                (ckpt[f'blk.{i}.ffn_gate.{j}.weight'])
-            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight'] = \
-                ckpt[f'blk.{i}.ffn_down.{j}.weight']
-            state_dict[f'model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight'] = \
-                ckpt[f'blk.{i}.ffn_up.{j}.weight']
-
     with init_empty_weights():
         model = MixtralForCausalLM(mixtral_config)
 
-    for name, weight in state_dict.items():
-        set_module_tensor_to_device(model, name, "cpu", weight, dtype=dtype)
-
-    model = model.cpu()
+    def process_mixtral(name, tensor):
+        module_name = get_mixtral_module_name(name)
+        if 'ffn_gate_inp' in name:
+            # gguf weight needs to reshape for ffn_gate_inp
+            set_module_tensor_to_device(model, module_name, "cpu", \
+                                        tensor.reshape(num_local_experts, hidden_size), dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, module_name, "cpu", \
+                                        tensor, dtype=dtype)
+
+    tensor_loader = loader.tensor_loader
+    tensor_loader.load_while_process(process_mixtral)
 
     from transformers.convert_slow_tokenizer import import_protobuf
     spm_pb2 = import_protobuf("Failed to import protobuf")
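The rewritten loading path above leans on Accelerate's meta-device workflow: init_empty_weights() builds the MixtralForCausalLM skeleton without allocating storage, and set_module_tensor_to_device() then materializes one parameter at a time as process_mixtral receives it. A minimal self-contained sketch of that pattern, with a toy module standing in for the real model:

    import torch
    import torch.nn as nn
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    with init_empty_weights():
        tiny = nn.Linear(4, 4, bias=False)   # parameters are created on the meta device

    # materialize a single parameter on CPU, as process_mixtral does per gguf tensor
    set_module_tensor_to_device(tiny, "weight", "cpu", torch.zeros(4, 4), dtype=torch.float)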
@@ -105,3 +82,37 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     os.remove(f.name)
 
     return model, tokenizer
+
+
+def get_mixtral_module_name(name):
+    if name == 'token_embd.weight':
+        return 'model.embed_tokens.weight'
+    if name == 'output_norm.weight':
+        return 'model.norm.weight'
+    if name == 'output.weight':
+        return 'lm_head.weight'
+    layer_id = name.split('.')[1]
+    if 'attn_q' in name:
+        return f'model.layers.{layer_id}.self_attn.q_proj.weight'
+    if 'attn_k' in name:
+        return f'model.layers.{layer_id}.self_attn.k_proj.weight'
+    if 'attn_v' in name:
+        return f'model.layers.{layer_id}.self_attn.v_proj.weight'
+    if 'attn_output' in name:
+        return f'model.layers.{layer_id}.self_attn.o_proj.weight'
+    if 'attn_norm' in name:
+        return f'model.layers.{layer_id}.input_layernorm.weight'
+    if 'ffn_norm' in name:
+        return f'model.layers.{layer_id}.post_attention_layernorm.weight'
+    if 'ffn_gate_inp' in name:
+        return f'model.layers.{layer_id}.block_sparse_moe.gate.weight'
+    local_expert_id = name.split('.')[3]
+    if 'ffn_gate' in name:
+        return f'model.layers.{layer_id}.' + \
+            f'block_sparse_moe.experts.{local_expert_id}.w1.weight'
+    if 'ffn_down' in name:
+        return f'model.layers.{layer_id}.' + \
+            f'block_sparse_moe.experts.{local_expert_id}.w2.weight'
+    if 'ffn_up' in name:
+        return f'model.layers.{layer_id}.' + \
+            f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
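For illustration, a few name translations that follow directly from the branches of get_mixtral_module_name above:

    assert get_mixtral_module_name('token_embd.weight') == 'model.embed_tokens.weight'
    assert get_mixtral_module_name('blk.0.attn_q.weight') == \
        'model.layers.0.self_attn.q_proj.weight'
    assert get_mixtral_module_name('blk.3.ffn_gate_inp.weight') == \
        'model.layers.3.block_sparse_moe.gate.weight'
    assert get_mixtral_module_name('blk.3.ffn_down.5.weight') == \
        'model.layers.3.block_sparse_moe.experts.5.w2.weight'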