parent b06a3146c8
commit df775cf316
1 changed file with 14 additions and 8 deletions
@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
+from accelerate.utils import set_module_tensor_to_device as fill_model
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
 
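Note on the renamed import: set_module_tensor_to_device is accelerate's helper for materializing tensors into a model skeleton built under init_empty_weights (i.e. on the meta device); the commit simply aliases it to fill_model. A minimal sketch of the call pattern, using a hypothetical toy nn.Linear in place of MixtralForCausalLM:

import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device as fill_model

# Build a skeleton on the meta device: no storage is allocated for
# parameters until each tensor is filled in explicitly.
with init_empty_weights():
    toy = nn.Linear(4, 2)  # hypothetical stand-in for MixtralForCausalLM

# Materialize each named parameter on CPU with the requested dtype,
# mirroring what the loader does per GGUF tensor in the hunk below.
fill_model(toy, "weight", "cpu", torch.randn(2, 4), dtype=torch.float16)
fill_model(toy, "bias", "cpu", torch.zeros(2), dtype=torch.float16)
print(toy.weight.dtype, toy.weight.device)  # torch.float16 cpu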
@@ -57,11 +57,17 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         module_name = get_mixtral_module_name(name)
         if 'ffn_gate_inp' in name:
             # gguf weight needs to reshape for ffn_gate_inp
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor.reshape(num_local_experts, hidden_size), dtype=dtype)
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor.reshape(num_local_experts, hidden_size),
+                       dtype=dtype)
         else:
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor, dtype=dtype)
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor,
+                       dtype=dtype)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
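For context on the reshape branch: as the in-code comment notes, the GGUF tensor for the MoE router (ffn_gate_inp) does not arrive in the layout Hugging Face expects. Mixtral's gate is nn.Linear(hidden_size, num_local_experts, bias=False), so its weight must have shape (num_local_experts, hidden_size). A minimal sketch of just the reshape, assuming Mixtral-8x7B's dimensions (8 experts, hidden size 4096), with a random buffer standing in for the GGUF tensor:

import torch

num_local_experts = 8   # Mixtral-8x7B config values, assumed here
hidden_size = 4096

# Flat buffer standing in for the GGUF-provided router weight.
tensor = torch.randn(num_local_experts * hidden_size)

# HF expects block_sparse_moe.gate.weight of shape (num_experts, hidden):
gate_weight = tensor.reshape(num_local_experts, hidden_size)
assert gate_weight.shape == (8, 4096)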
@@ -83,6 +89,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
 
     return model, tokenizer
 
+
 def get_mixtral_module_name(name):
     if name == 'token_embd.weight':
         return 'model.embed_tokens.weight'
@@ -92,7 +99,7 @@ def get_mixtral_module_name(name):
         return 'lm_head.weight'
     layer_id = name.split('.')[1]
     if 'attn_q' in name:
         return f'model.layers.{layer_id}.self_attn.q_proj.weight'
     if 'attn_k' in name:
         return f'model.layers.{layer_id}.self_attn.k_proj.weight'
     if 'attn_v' in name:
@@ -115,4 +122,3 @@ def get_mixtral_module_name(name):
     if 'ffn_up' in name:
         return f'model.layers.{layer_id}.' + \
             f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
-
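Taken together, get_mixtral_module_name translates GGUF tensor names into Hugging Face Mixtral module paths; note that name.split('.')[1] extracts the layer id from gguf's blk.<layer>.<tensor> convention. A few illustrative pairs implied by the hunks above; the 'output.weight' key and the expert-id position in the ffn_up name are assumptions based on standard gguf naming, not shown in this diff:

# Illustrative GGUF -> HF name pairs implied by the function above.
examples = {
    'token_embd.weight':     'model.embed_tokens.weight',
    'output.weight':         'lm_head.weight',          # assumed key
    'blk.0.attn_q.weight':   'model.layers.0.self_attn.q_proj.weight',
    'blk.0.attn_k.weight':   'model.layers.0.self_attn.k_proj.weight',
    # ffn_up maps to the experts' w3 projection; the expert-id
    # placement in the GGUF name is an assumption:
    'blk.0.ffn_up.7.weight': 'model.layers.0.block_sparse_moe.experts.7.w3.weight',
}
for gguf_name, hf_name in examples.items():
    print(f'{gguf_name:24s} -> {hf_name}')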