[NPU] support asym_int4 for minicpm (#12567)
parent 6e801bc4e1
commit 1a2ab12876

2 changed files with 146 additions and 40 deletions
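Summary (added for context; not part of the original commit message): the change threads a new `asym` flag through the MiniCPM NPU decoder classes (LowBitMinicpmMultiDecoderlayer, FusedLlamaLowBitMultiDecoderlayer, FusedLlamaLowBitDecoderlayer), the run_decode/run_prefill worker loops, MiniCPMLMHead, and the convert_lm_head_and_embedding, convert_minicpm_layer and convert_fused_minicpm_layer export helpers. Wherever the symmetric int4 path packed (weight, scale) pairs, the asymmetric path also carries a zero-point tensor, so tuples become (weight, scale, zero) and each split contributes a third array to the exported weight bin files. File paths are not shown in this diff view; the first group of hunks belongs to the decoder-layer module that defines LowBitMinicpmMultiDecoderlayer and run_decode/run_prefill, the second to the MiniCPM conversion module.

The zero-point exists because asymmetric quantization maps the int4 codes onto an offset range. A minimal numeric sketch of one common convention (illustrative only; the exact formula used by the asym_int4_rtn kernels is not visible in this diff):

    import numpy as np

    scale, zero = 0.05, 7.0               # per-group quantization parameters
    q = np.array([0.0, 7.0, 15.0])        # 4-bit codes in 0..15
    w_sym = q * scale                     # symmetric int4: w ~= q * scale
    w_asym = (q - zero) * scale           # asymmetric int4: w ~= (q - zero) * scale
    print(w_sym, w_asym)

How the flag gets set is outside this diff: the converters read getattr(model.config, "asym", False) or infer it from a QuantizedLinear's qtype == "asym_int4_rtn", so presumably loading the model with the asym_int4 low-bit option (for example load_in_low_bit="asym_int4", as in the other ipex-llm NPU examples) is what enables it; check the accompanying examples for the exact arguments.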
@@ -81,7 +81,8 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
         num_hidden_layers,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -90,7 +91,8 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
                          device=device,
                          n_splits_linear=n_splits_linear,
                          n_splits_down_proj=n_splits_down_proj,
-                         group_size=group_size)
+                         group_size=group_size,
+                         asym=asym)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -272,7 +274,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         do_print: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()

@@ -280,8 +283,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):

         op_parameters = []
         for w in parameters:
-            if isinstance(w, tuple):  # from QuantizedLinear
+            if isinstance(w, tuple) and not asym:  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif isinstance(w, tuple) and asym:  # from QuantizedLinear
+                op_parameters.append((w[0].numpy(), w[1].numpy(), w[2].numpy()))
             elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
                 op_parameters.append(w.numpy())
             elif isinstance(w, np.ndarray):  # scale
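Note (added for context, not part of the commit): the hunk above is the consumption side of the new tuple shape. With asym enabled, each QuantizedLinear hands over a zero-point tensor alongside the packed weight and scale, so the tuples pushed into op_parameters grow from two numpy arrays to three. A minimal sketch of that contract, with a hypothetical helper name:

    import torch

    def pack_op_parameter(w, asym: bool):
        # w is either a tuple from a QuantizedLinear, (weight, scale) for symmetric
        # int4 or (weight, scale, zero) when asym is True, or a plain tensor
        if isinstance(w, tuple):
            expected = 3 if asym else 2
            assert len(w) == expected, f"expected a {expected}-tuple, got {len(w)}"
            return tuple(t.numpy() for t in w)
        return w.numpy()  # non-tuple entries pass through unchanged, as before

    weight = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
    scale = torch.rand(4, 1, dtype=torch.float16)
    zero = torch.rand(4, 1, dtype=torch.float16)
    print(len(pack_op_parameter((weight, scale), asym=False)))       # 2
    print(len(pack_op_parameter((weight, scale, zero), asym=True)))  # 3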
@@ -336,7 +341,8 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
                 dtype=np_dtype,
                 n_splits_linear=n_splits_linear,
                 n_splits_down_proj=n_splits_down_proj,
-                group_size=group_size
+                group_size=group_size,
+                asym=asym,
             )
             self.backend_decoders.append(decoder)

@@ -414,7 +420,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         transpose_value: bool = False,
         n_splits_linear: int = 1,
         n_splits_down_proj: int = 1,
-        group_size: int = 0
+        group_size: int = 0,
+        asym: bool = False,
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -447,7 +454,8 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym,
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -534,6 +542,7 @@ def run_decode(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -546,9 +555,16 @@ def run_decode(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
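Note (added for context, not part of the commit): this collect-and-stack block now also gathers per-split zero-points, and the same block is repeated nearly verbatim in run_prefill, convert_minicpm_layer and convert_fused_minicpm_layer below. Captured once as a sketch, with a hypothetical helper name (the repo keeps the loop inline):

    import torch

    def stack_quantized_split(layer_list):
        # layer_list: the per-split QuantizedLinear modules of one projection
        l_weights = [l.weight for l in layer_list]
        scales = [l.scale for l in layer_list]
        zeros = [l.zero for l in layer_list if l.zero is not None]
        if zeros:  # asym_int4: zero-points exist, emit a 3-tuple
            return (torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
                    torch.stack(zeros, axis=0))
        # symmetric int4: keep the original 2-tuple layout
        return (torch.stack(l_weights, axis=0), torch.stack(scales, axis=0))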
@@ -580,7 +596,8 @@ def run_decode(
         do_print=False,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym,
     )

     dist.barrier()
@@ -753,6 +770,7 @@ def run_prefill(
     layer_indexs = range(layer_start, layer_end)
     n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
     n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
+    asym = getattr(model.config, "asym", False)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
@@ -765,9 +783,16 @@ def run_prefill(
                            mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
+            zeros = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -793,7 +818,8 @@ def run_prefill(
             transpose_value=transpose_value_cache,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )

     layer_weights.extend(weights)

(second changed file: the MiniCPM conversion module with MiniCPMLMHead, convert_lm_head_and_embedding, convert_minicpm_layer and convert_fused_minicpm_layer)

@@ -105,6 +105,7 @@ class MiniCPMLMHead(LLMBaseNNFactory):
         profile: bool = False,
         device: str = "NPU",
         n_splits: int = 1,
+        asym: bool = False
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
@@ -134,11 +135,13 @@ class MiniCPMLMHead(LLMBaseNNFactory):
             # for MiniCPM-2B-sft-bf16
             hidden_states_1 = self.linear(
                 hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
-                n_splits=n_splits, scale_factor=(n_splits == 1)
+                n_splits=n_splits, scale_factor=(n_splits == 1),
+                asym=asym
             )
             hidden_states_2 = self.linear(
                 hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
-                n_splits=n_splits, scale_factor=(n_splits == 1)
+                n_splits=n_splits, scale_factor=(n_splits == 1),
+                asym=asym
             )

             hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
@@ -147,7 +150,8 @@ class MiniCPMLMHead(LLMBaseNNFactory):
             # for MiniCPM-1B-sft-bf16
             hidden_states = self.linear(
                 hidden_states, self.vocab_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1)
+                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1),
+                asym=asym
             )

         # define outputs
@@ -165,26 +169,46 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
+    asym = getattr(model.config, "asym", False)
     if n_splits_linear == 1:
         if vocab_size == 122753:
             # for MiniCPM-2B-sft-bf16
-            weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
-                       (model.lm_head_1.weight, model.lm_head_1.scale)]
+            asym = model.lm_head_0.qtype == "asym_int4_rtn"
+            if asym:
+                weights = [(model.lm_head_0.weight, model.lm_head_0.scale, model.lm_head_0.zero),
+                           (model.lm_head_1.weight, model.lm_head_1.scale, model.lm_head_1.zero)]
+            else:
+                weights = [(model.lm_head_0.weight, model.lm_head_0.scale),
+                           (model.lm_head_1.weight, model.lm_head_1.scale)]
         else:
             # for MiniCPM-1B-sft-bf16
-            weights = [(model.lm_head.weight, model.lm_head.scale)]
+            asym = model.lm_head.qtype == "asym_int4_rtn"
+            if asym:
+                weights = [(model.lm_head.weight, model.lm_head.scale, model.lm_head.zero)]
+            else:
+                weights = [(model.lm_head.weight, model.lm_head.scale)]
     else:
         weights = []
         if vocab_size == 122753:
+            asym = model.lm_head_0.lm_heads[0].qtype == "asym_int4_rtn"
             lm_head_list = [model.lm_head_0.lm_heads, model.lm_head_1.lm_heads]
         else:
+            asym = model.lm_head.lm_heads[0].qtype == "asym_int4_rtn"
             lm_head_list = [model.lm_head.lm_heads]
         for lh in lm_head_list:
             lm_head_weights = []
             scales = []
+            zeros = []
             for l in lh:
                 lm_head_weights.append(l.weight)
                 scales.append(l.scale)
-            weights.append((torch.stack(lm_head_weights, axis=0),
-                            torch.stack(scales, axis=0)))
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(lm_head_weights, axis=0),
+                                torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(lm_head_weights, axis=0),
+                                torch.stack(scales, axis=0)))
     if isinstance(weights[0], tuple):
@@ -202,7 +226,8 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
-        n_splits=n_splits_linear
+        n_splits=n_splits_linear,
+        asym=asym
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir,
                                                         True, True)
@@ -210,12 +235,24 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir,
     # save weights bins files
     if n_splits_linear == 1:
         if vocab_size == 122753:
-            weight_numpy = [model.lm_head_0.weight.data.numpy(),
-                            model.lm_head_0.scale.data.numpy(),
-                            model.lm_head_1.weight.data.numpy(),
-                            model.lm_head_1.scale.data.numpy(), ]
+            if not asym:
+                weight_numpy = [model.lm_head_0.weight.data.numpy(),
+                                model.lm_head_0.scale.data.numpy(),
+                                model.lm_head_1.weight.data.numpy(),
+                                model.lm_head_1.scale.data.numpy(), ]
+            else:
+                weight_numpy = [model.lm_head_0.weight.data.numpy(),
+                                model.lm_head_0.scale.data.numpy(),
+                                model.lm_head_0.zero.data.numpy(),
+                                model.lm_head_1.weight.data.numpy(),
+                                model.lm_head_1.scale.data.numpy(),
+                                model.lm_head_1.zero.data.numpy(), ]
         else:
-            weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
+            if not asym:
+                weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy()]
+            else:
+                weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(),
+                                model.lm_head.zero.data.numpy()]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]
     if vocab_size == 122753:
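Note (added for context, not part of the commit): with asym_int4 the LM-head binaries change layout, since each head now contributes three arrays (weight, scale, zero) instead of two, in that order. A compact sketch of the save order, assuming a hypothetical heads list such as [model.lm_head_0, model.lm_head_1] whose items expose .weight/.scale/.zero as above:

    def lm_head_weight_arrays(heads, asym):
        # flatten per-head tensors into the order written to the weight bin files
        arrays = []
        for head in heads:
            arrays.append(head.weight.data.numpy())
            arrays.append(head.scale.data.numpy())
            if asym:
                arrays.append(head.zero.data.numpy())  # zero-point saved only for asym_int4
        return arrays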
@@ -266,6 +303,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     rms_norm_eps = model.config.rms_norm_eps
     num_hidden_layers = model.config.num_hidden_layers
     scale_depth = model.model.config.scale_depth
+    asym = getattr(model.config, "asym", False)

     from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer
     curr_layer = model.model.layers[layer_idx]
@@ -279,9 +317,16 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
                        mlp_layer.down_proj_dq_list]:
         l_weights = []
         scales = []
+        zeros = []
         for l in layer_list:
             l_weights.append(l.weight)
             scales.append(l.scale)
-        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+            if l.zero is not None:
+                zeros.append(l.zero)
+        if len(zeros):
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                            torch.stack(zeros, axis=0)))
+        else:
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -321,7 +366,8 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         dtype=np_dtype,
         n_splits_linear=n_splits_linear,
         n_splits_down_proj=n_splits_down_proj,
-        group_size=group_size
+        group_size=group_size,
+        asym=asym
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                         decoder_name,
@@ -337,11 +383,23 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     layer_norm_0.data.numpy().tofile(input_lm_bin_file)
     layer_norm_1.data.numpy().tofile(post_lm_bin_file)
     st_idx = 7
-    for idx, (weight, scale) in enumerate(weights):
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-        weight.numpy().tofile(bin_file)
-        bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-        scale.numpy().tofile(bin_file)
+    if not asym:
+        for idx, (weight, scale) in enumerate(weights):
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+            weight.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+            scale.numpy().tofile(bin_file)
+    else:
+        for idx, (weight, scale, zero) in enumerate(weights):
+            bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+            weight.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir,
+                                    f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+            scale.numpy().tofile(bin_file)
+            bin_file = os.path.join(weight_dir,
+                                    f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+            zero.numpy().tofile(bin_file)

     del single_decoder
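Note (added for context, not part of the commit): the per-layer input index stride grows from 2 to 3 when zero-points are saved, both here (st_idx = 7) and in convert_fused_minicpm_layer below (st_idx = 5, where inputs 6 and 7 are the past k/v). A small sketch of the resulting file-name mapping, with hypothetical parameter names:

    import os

    def layer_input_bin_paths(weight_dir, layer_idx, num_tuples, st_idx=7, asym=False):
        # each weights tuple occupies 2 consecutive input slots (weight, scale),
        # or 3 slots (weight, scale, zero) when asym_int4 zero-points are stored
        stride = 3 if asym else 2
        paths = []
        for idx in range(num_tuples):
            for offset in range(stride):
                paths.append(os.path.join(
                    weight_dir, f"model_{layer_idx}_input_{st_idx + idx * stride + offset}.bin"))
        return paths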
@@ -357,6 +415,7 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
     scale_depth = model.model.config.scale_depth
     layer_num = len(model.model.layers)
     fused_layer_num = layer_num // fused_layers
+    asym = getattr(model.config, "asym", False)

     from ipex_llm.transformers.npu_models.minicpm_mp import LowBitMinicpmMultiDecoderlayer
     for i in range(fused_layers):
@@ -380,9 +439,16 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
                                mlp_layer.down_proj_dq_list]:
                 l_weights = []
                 scales = []
+                zeros = []
                 for l in layer_list:
                     l_weights.append(l.weight)
                     scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+                    if l.zero is not None:
+                        zeros.append(l.zero)
+                if len(zeros):
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                    torch.stack(zeros, axis=0)))
+                else:
+                    weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -401,12 +467,25 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
             layer_norm_1.data.numpy().tofile(post_lm_bin_file)
             st_idx = 5
             # 6, 7 are past k/v
-            for idx, (weight, scale) in enumerate(weights):
-                bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
-                weight.numpy().tofile(bin_file)
-                bin_file = os.path.join(weight_dir,
-                                        f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
-                scale.numpy().tofile(bin_file)
+            if not asym:
+                for idx, (weight, scale) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*2}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*2+1}.bin")
+                    scale.numpy().tofile(bin_file)
+            else:
+                for idx, (weight, scale, zero) in enumerate(weights):
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3}.bin")
+                    weight.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3+1}.bin")
+                    scale.numpy().tofile(bin_file)
+                    bin_file = os.path.join(weight_dir,
+                                            f"model_{layer_idx}_input_{st_idx+idx*3+2}.bin")
+                    zero.numpy().tofile(bin_file)

         if isinstance(weights[0], tuple):
             np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
@@ -432,7 +511,8 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
             dtype=np_dtype,
             n_splits_linear=n_splits_linear,
             n_splits_down_proj=n_splits_down_proj,
-            group_size=group_size
+            group_size=group_size,
+            asym=asym
         )
         update_names_of_IR_and_export_blob(fused_decoder,
                                            f"decoder_layer_{i}",