fix python style (#9742)

* fix python style

* fix

* fix
commit df775cf316 (parent b06a3146c8)
Author: Heyang Sun
Date:   2023-12-21 11:25:05 +08:00
Committed by: GitHub


@@ -17,7 +17,7 @@
 import os
 import torch
 from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
+from accelerate.utils import set_module_tensor_to_device as fill_model
 from tempfile import NamedTemporaryFile
 from transformers import MixtralConfig, MixtralForCausalLM, LlamaTokenizer
@@ -57,11 +57,17 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
         module_name = get_mixtral_module_name(name)
         if 'ffn_gate_inp' in name:
             # gguf weight needs to reshape for ffn_gate_inp
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor.reshape(num_local_experts, hidden_size), dtype=dtype)
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor.reshape(num_local_experts, hidden_size),
+                       dtype=dtype)
         else:
-            set_module_tensor_to_device(model, module_name, "cpu", \
-                tensor, dtype=dtype)
+            fill_model(model,
+                       module_name,
+                       "cpu",
+                       tensor,
+                       dtype=dtype)
 
     tensor_loader = loader.tensor_loader
     tensor_loader.load_while_process(process_mixtral)
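Note (not part of the diff): the `ffn_gate_inp` branch reshapes the MoE router weight because the tensor is laid out differently in the GGUF file than Mixtral's `block_sparse_moe.gate` expects, namely `(num_local_experts, hidden_size)`. A minimal sketch of that reshape, assuming Mixtral-8x7B's shapes (8 experts, hidden size 4096):

import torch

num_local_experts, hidden_size = 8, 4096              # Mixtral-8x7B values (assumed)
flat = torch.randn(num_local_experts * hidden_size)   # stand-in for the GGUF router tensor
gate = flat.reshape(num_local_experts, hidden_size)   # layout the HF gate module expects
assert gate.shape == (num_local_experts, hidden_size)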
@@ -83,6 +89,7 @@ def load_gguf_mixtral(loader: GGUFFileLoader, dtype: torch.dtype = torch.float):
     return model, tokenizer
 
+
 def get_mixtral_module_name(name):
     if name == 'token_embd.weight':
         return 'model.embed_tokens.weight'
@@ -115,4 +122,3 @@ def get_mixtral_module_name(name):
     if 'ffn_up' in name:
         return f'model.layers.{layer_id}.' + \
             f'block_sparse_moe.experts.{local_expert_id}.w3.weight'
-
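Note (not part of the diff): `fill_model` is just an alias for accelerate's `set_module_tensor_to_device`, which places a real tensor into a model whose parameters were created on the meta device via `init_empty_weights`. A minimal, self-contained sketch of that pattern, using a deliberately tiny `MixtralConfig` so it runs quickly (the real loader uses shapes and values read from the GGUF file):

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device as fill_model
from transformers import MixtralConfig, MixtralForCausalLM

# Tiny illustrative config (assumed values, not the real Mixtral-8x7B sizes).
config = MixtralConfig(hidden_size=64, intermediate_size=128,
                       num_hidden_layers=2, num_attention_heads=4,
                       num_key_value_heads=2, vocab_size=512)

with init_empty_weights():
    model = MixtralForCausalLM(config)  # all parameters live on the meta device

# Stand-in for a tensor produced by the GGUF tensor loader.
tensor = torch.randn(config.vocab_size, config.hidden_size)

# Replaces the meta placeholder with a real CPU tensor in the given dtype.
fill_model(model, 'model.embed_tokens.weight', 'cpu', tensor, dtype=torch.float32)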