[NPU] Llama3, Qwen2 1.5b, MiniCPM 1/2B groupwise support (#12327)
* support minicpm 1b & qwen 1.5b gw
* support minicpm 1b
* support minicpm 2b
* fix style & error
* fix style & update
* remove print
parent 82a61b5cf3
commit d872639395

9 changed files with 239 additions and 68 deletions
@@ -47,6 +47,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
     parser.add_argument("--max-prompt-len", type=int, default=512)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--disable-streaming", action="store_true", default=False)

@@ -61,6 +62,7 @@ if __name__ == "__main__":
                                                      max_prompt_len=args.max_prompt_len,
                                                      torch_dtype=torch.float16,
                                                      attn_implementation="eager",
+                                                     quantization_group_size=args.quantization_group_size,
                                                      transpose_value_cache=not args.disable_transpose_value_cache,
                                                      trust_remote_code=True)
     else:
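Taken together, the two example-script hunks above add a `--quantization_group_size` flag and pass it straight through to model loading. A minimal usage sketch, assuming the `ipex_llm.transformers.npu_model.AutoModelForCausalLM` entry point and the `load_in_low_bit`/`optimize_model` keywords used in the published NPU examples; the checkpoint id and the group size of 64 are illustrative only:

    import torch
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # Group-wise low-bit quantization on the NPU; 0 keeps the previous
    # channel-wise behaviour of these examples.
    model = AutoModelForCausalLM.from_pretrained(
        "openbmb/MiniCPM-1B-sft-bf16",      # illustrative; a Llama3 or Qwen2-1.5B id also fits
        load_in_low_bit="sym_int4",
        optimize_model=True,
        max_context_len=1024,
        max_prompt_len=512,
        quantization_group_size=64,         # the new argument wired up above
        torch_dtype=torch.float16,
        attn_implementation="eager",
        transpose_value_cache=True,
        trust_remote_code=True,
    )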
@@ -76,13 +76,19 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
     from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
     attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
     mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
-    if isinstance(module, (Qwen2Attention, LlamaAttention)):
+    if (
+        isinstance(module, (Qwen2Attention, LlamaAttention))
+        or module.__class__.__name__ in ['MiniCPMAttention', 'Attention']
+    ):
         for name in attn_module_names:
             setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
                                                             n_splits=n_splits_hidden_size,
                                                             load=load))
             delattr(module, name)
-    elif isinstance(module, (Qwen2MLP, LlamaMLP)):
+    elif (
+        isinstance(module, (Qwen2MLP, LlamaMLP))
+        or module.__class__.__name__ in ['MiniCPMMLP', 'MLP']
+    ):
         for name in mlp_module_names:
             n_splits_mlp = n_splits_hidden_size
             if name == 'down_proj':
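MiniCPM's decoder classes come from remote modeling code (the examples load with `trust_remote_code=True`), so they cannot be matched with `isinstance` against importable transformers classes; the hunk above therefore also matches on class names. The `split_linear` helper itself is ipex-llm internal and not part of this diff; the sketch below only illustrates the underlying idea with a hypothetical splitter that cuts a projection into column blocks along the input dimension, so each block can later carry its own quantization scales:

    import torch
    from torch import nn

    def split_linear_sketch(linear: nn.Linear, n_splits: int) -> nn.ModuleList:
        """Hypothetical illustration: split one Linear into n_splits Linears,
        each seeing a contiguous slice of the input features."""
        in_features = linear.in_features
        assert in_features % n_splits == 0
        step = in_features // n_splits
        chunks = nn.ModuleList()
        for i in range(n_splits):
            piece = nn.Linear(step, linear.out_features, bias=False)
            # column slice of the original weight: [out_features, step]
            piece.weight.data = linear.weight.data[:, i * step:(i + 1) * step].clone()
            chunks.append(piece)
        return chunks

    # The split is lossless: summing the partial outputs reproduces the original GEMM.
    x = torch.randn(1, 64)
    full = nn.Linear(64, 16, bias=False)
    parts = split_linear_sketch(full, n_splits=4)
    partial = sum(p(x[:, i * 16:(i + 1) * 16]) for i, p in enumerate(parts))
    print(torch.allclose(full(x), partial, atol=1e-5))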
@@ -87,9 +87,8 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
         model.llm.config.model_type = "llama"
         model = model.llm
-
-    if model.config.model_type in ["qwen2", "llama"]:
+    if model.config.model_type in ["qwen2", "llama", "minicpm"]:
         from ipex_llm.transformers.npu_models.common import split_linears

         if quantization_group_size == 0:
             n_splits_linear = 1
             n_splits_down_proj = 2 if model.config.intermediate_size == 18944 else 1

@@ -110,6 +109,17 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,

         if quantization_group_size != 0:
             split_num = model.config.hidden_size // quantization_group_size
+            if model.config.model_type == "minicpm" and model.config.num_hidden_layers == 40:
+                # workaround for MiniCPM-2B
+                new_lm_head_0 = SlicedLMHead(model.lm_head_0.weight, split_num=split_num,
+                                             bias=model.lm_head_0.bias, use_split=True)
+                del model.lm_head_0
+                model.lm_head_0 = new_lm_head_0
+                new_lm_head_1 = SlicedLMHead(model.lm_head_1.weight, split_num=split_num,
+                                             bias=model.lm_head_1.bias, use_split=True)
+                del model.lm_head_1
+                model.lm_head_1 = new_lm_head_1
+            else:
                 new_lm_head = SlicedLMHead(model.lm_head.weight, split_num=split_num,
                                            bias=model.lm_head.bias, use_split=True)
                 del model.lm_head
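When a non-zero group size is requested, the LM head is wrapped in a `SlicedLMHead` whose number of slices is derived from the hidden size; MiniCPM-2B (40 hidden layers, two tied heads `lm_head_0`/`lm_head_1`) gets both heads wrapped. A hedged sketch of the arithmetic and of what slicing the vocabulary projection means; `SlicedLMHead` is ipex-llm internal, so the tensors below are plain stand-ins and the hidden/vocab sizes are illustrative:

    import torch

    hidden_size = 1536                    # illustrative value only
    quantization_group_size = 64
    split_num = hidden_size // quantization_group_size   # 24 slices

    # The lm_head weight [vocab_size, hidden_size] is cut into split_num column
    # blocks, so each block spans exactly one quantization group of the input.
    vocab_size = 32000                    # illustrative value only
    w = torch.randn(vocab_size, hidden_size)
    slices = torch.chunk(w, split_num, dim=1)
    print(len(slices), slices[0].shape)   # 24, torch.Size([32000, 64])

    # Summing the partial matmuls reproduces the full projection.
    x = torch.randn(1, hidden_size)
    logits = sum(xs @ ws.t() for xs, ws in zip(torch.chunk(x, split_num, dim=1), slices))
    print(torch.allclose(logits, x @ w.t(), atol=1e-4))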
@@ -372,6 +382,10 @@ def optimize_llm(
                              transpose_value_cache=transpose_value_cache)
     if hasattr(model, 'lm_head') and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
+    # MiniCPM-2b
+    if hasattr(model, "lm_head_1") and isinstance(model.lm_head_1, SlicedLMHead):
+        model.lm_head_1.get_fused_lm_head()
+        model.lm_head_0.get_fused_lm_head()


 def optimize_funasr(
@@ -110,8 +110,8 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))

-        # llama2 use ov sdp, other models need to test
-        use_prefill_sdp = self.intermediate_size == 11008
+        # llama2/3 use ov sdp, other models need to test
+        use_prefill_sdp = self.intermediate_size in [11008, 14336]

         # Self Attention
         if mode == "decode":

@@ -437,7 +437,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
-        self.use_prefill_sdp = intermediate_size == 11008
+        self.use_prefill_sdp = intermediate_size in [11008, 14336]

     def forward(
         self,
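The intermediate-size check is a proxy for the model family: 11008 is the MLP width of Llama-2-7B-style checkpoints and 14336 that of Llama-3-8B-style checkpoints, so the OpenVINO SDP prefill path now covers both, as the updated comment says. A small hedged restatement of the gate as a stand-alone helper (the 8960 value used as a negative example is assumed to be a Qwen2-1.5B-style width):

    # Hedged note: 11008 and 14336 correspond to Llama-2-7B- and Llama-3-8B-style
    # checkpoints respectively; other widths keep the non-SDP prefill path.
    KNOWN_SDP_INTERMEDIATE_SIZES = [11008, 14336]

    def should_use_prefill_sdp(intermediate_size: int) -> bool:
        return intermediate_size in KNOWN_SDP_INTERMEDIATE_SIZES

    print(should_use_prefill_sdp(11008), should_use_prefill_sdp(14336), should_use_prefill_sdp(8960))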
@@ -78,13 +78,19 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
         rms_norm_eps,
         intermediate_size,
         scale_depth,
-        num_hidden_layers
+        num_hidden_layers,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype

@@ -235,7 +241,7 @@ class LowBitMinicpmMultiDecoderlayer(LLMBaseNNFactory):
                                          attn_output * layer_scale_depth)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual,
                                          hidden_states * layer_scale_depth)
         hidden_states = self.convert_to_fp16(hidden_states)

@@ -264,6 +270,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()

@@ -273,6 +282,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):  # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
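With group-wise splitting, the decoder no longer always receives `(weight, scale)` tuples: stacked int8/uint8 weights and scales that were already converted to numpy can arrive as separate entries, which is what the new branches above normalize. A hedged stand-alone sketch of the same dispatch; the function name is illustrative and not the ipex-llm API:

    import numpy as np
    import torch

    def to_op_parameter(w):
        """Normalize one decoder parameter, mirroring the branches in the hunk above."""
        if isinstance(w, tuple):                       # (weight, scale) from QuantizedLinear
            return (w[0].numpy(), w[1].numpy())
        elif isinstance(w, torch.Tensor) and w.dtype in (torch.int8, torch.uint8):
            return w.numpy()                           # quantized weight passed on its own
        elif isinstance(w, np.ndarray):                # scale already converted to numpy
            return w
        else:
            return w.to(torch.float16).numpy()         # plain FP16 tensor (norms, embeddings, ...)

    print(to_op_parameter(torch.ones(2, 2)).dtype)                      # float16
    print(to_op_parameter(torch.zeros(2, 2, dtype=torch.int8)).dtype)   # int8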
@@ -281,6 +294,10 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16

@@ -317,6 +334,9 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)

@@ -392,6 +412,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         num_hidden_layers,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
         self.op_parameters = parameters

@@ -422,6 +445,9 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1

@@ -501,24 +527,53 @@ def run_decode(
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
     num_hidden_layers = model.config.num_hidden_layers
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
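After `split_linears`, each projection exists as a list of per-group `QuantizedLinear` slices (`*_dq_list`); when more than one split is used, their weights and scales are stacked along a new leading dimension so the NPU decoder receives one packed tensor per projection. A minimal hedged sketch of that packing with dummy tensors (shapes and the per-slice scale layout are illustrative):

    import torch

    n_splits = 4
    out_features, group_in = 128, 64          # each slice covers one 64-wide input group

    # Stand-ins for the per-slice QuantizedLinear weight/scale pairs.
    slice_weights = [torch.randint(-128, 127, (out_features, group_in), dtype=torch.int8)
                     for _ in range(n_splits)]
    slice_scales = [torch.rand(out_features, dtype=torch.float16) for _ in range(n_splits)]

    packed_weight = torch.stack(slice_weights, dim=0)   # [n_splits, out_features, group_in]
    packed_scale = torch.stack(slice_scales, dim=0)     # [n_splits, out_features]
    print(packed_weight.shape, packed_scale.shape)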
@@ -547,6 +602,9 @@ def run_decode(
         max_seq_len=max_seq_len,
         transpose_value=transpose_value_cache,
         do_print=False,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )

     dist.barrier()

@@ -711,25 +769,55 @@ def run_prefill(
     intermediate_size = model.config.intermediate_size
     scale_depth = model.config.scale_depth
     num_hidden_layers = model.config.num_hidden_layers
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp

-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

@@ -752,6 +840,9 @@ def run_prefill(
             num_hidden_layers=num_hidden_layers,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )

         layer_weights.extend(weights)
@@ -273,7 +273,6 @@ class LLMBaseNNFactory(NNFactory):
                                        self.n_splits_linear, wt_dtype=self.dtype,
                                        scale_factor=(self.group_size == 0),
                                        is_prefill=(mode == "prefill"))
-
        return attn_output, new_key_states, new_value_states

     def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias):

@@ -370,6 +370,9 @@ def convert_llm(model: torch.nn.Module,

     if hasattr(model, "lm_head") and isinstance(model.lm_head, SlicedLMHead):
         model.lm_head.get_fused_lm_head()
+    if hasattr(model, "lm_head_1") and isinstance(model.lm_head_1, SlicedLMHead):
+        model.lm_head_1.get_fused_lm_head()
+        model.lm_head_0.get_fused_lm_head()

     # patch generate function
     import types

@@ -81,6 +81,7 @@ class MiniCPMLMHead(LLMBaseNNFactory):
         transpose_value: bool = False,
         profile: bool = False,
         device: str = "NPU",
+        n_splits: int = 1,
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,

@@ -108,18 +109,36 @@ class MiniCPMLMHead(LLMBaseNNFactory):
         hidden_states = self.layer_norm(hidden_states, model_norm_weight)
         if vocab_size == 122753:
             # for MiniCPM-2B-sft-bf16
+            if n_splits == 1:
                 hidden_states_1 = self.linear(
                     hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
                 )
                 hidden_states_2 = self.linear(
                     hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
                 )
+            else:
+                hidden_states_1 = self.dq_split_linear(
+                    hidden_states, 73440, self.hidden_size,
+                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                )
+                hidden_states_2 = self.dq_split_linear(
+                    hidden_states, 73440, self.hidden_size,
+                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                )

             hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
             hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
         else:
             # for MiniCPM-1B-sft-bf16
+            if n_splits == 1:
                 hidden_states = self.linear(
-                    hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
+                    hidden_states, self.vocab_size, self.hidden_size, bias=False,
+                    wt_dtype=self.dtype
+                )
+            else:
+                hidden_states = self.dq_split_linear(
+                    hidden_states, self.vocab_size, self.hidden_size,
+                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
                 )

         # define outputs
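For MiniCPM-2B (vocab_size 122753) the head is built as two 73440-wide projections, presumably to keep each matmul within NPU-friendly size limits, and the second output is sliced before concatenation: 73440 + 49313 = 122753, the full vocabulary. A tiny hedged sketch of that shape bookkeeping with plain tensors (the hidden size is illustrative):

    import torch

    vocab_size, hidden_size = 122753, 2304     # hidden size is illustrative
    head_width = 73440

    x = torch.randn(1, 1, hidden_size)
    w0 = torch.randn(head_width, hidden_size)
    w1 = torch.randn(head_width, hidden_size)

    logits_1 = x @ w0.t()                                     # [1, 1, 73440]
    logits_2 = (x @ w1.t())[..., :vocab_size - head_width]    # keep 49313 columns
    logits = torch.cat([logits_1, logits_2], dim=-1)
    print(logits.shape)                                       # torch.Size([1, 1, 122753])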
@@ -145,8 +164,19 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
             # for MiniCPM-1B-sft-bf16
             weights = [(model.lm_head.weight, model.lm_head.scale)]
     else:
-        # TODO
-        pass
+        weights = []
+        if vocab_size == 122753:
+            lm_head_list = [model.lm_head_0.lm_heads, model.lm_head_1.lm_heads]
+        else:
+            lm_head_list = [model.lm_head.lm_heads]
+        for lh in lm_head_list:
+            lm_head_weights = []
+            scales = []
+            for l in lh:
+                lm_head_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(lm_head_weights, axis=0),
+                            torch.stack(scales, axis=0)))
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear

@@ -162,6 +192,7 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
+        n_splits=n_splits_linear
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)

@@ -175,8 +206,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         else:
             weight_numpy = [model.lm_head.weight.data.numpy(), model.lm_head.scale.data.numpy(), ]
     else:
-        # TODO
-        pass
+        weight_numpy = [v.numpy() for v in weights[0]]
+        if vocab_size == 122753:
+            weight_numpy.extend([v.numpy() for v in weights[1]])

     for idx, weight in enumerate(weight_numpy):
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")

@@ -214,18 +246,40 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,

     weights = []
     if n_splits_linear == 1:
-        weights = [
-            (attn_layer.q_proj.weight, attn_layer.q_proj.scale),
-            (attn_layer.k_proj.weight, attn_layer.k_proj.scale),
-            (attn_layer.v_proj.weight, attn_layer.v_proj.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                    attn_layer.k_proj_dq_list,
+                                    attn_layer.v_proj_dq_list,
+                                    attn_layer.o_proj_dq_list,
+                                    mlp_layer.gate_proj_dq_list,
+                                    mlp_layer.up_proj_dq_list):
+            weights.append((q.weight, q.scale))
+            weights.append((k.weight, k.scale))
+            weights.append((v.weight, v.scale))
+            weights.append((o.weight, o.scale))
+            weights.append((g.weight, g.scale))
+            weights.append((u.weight, u.scale))
     else:
-        # TODO
-        pass
+        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0),
+                            torch.stack(scales, axis=0)))
+
+    if n_splits_down_proj == 1:
+        for l in mlp_layer.down_proj_dq_list:
+            weights.append((l.weight, l.scale))
+    else:
+        l_weights = []
+        scales = []
+        for l in mlp_layer.down_proj_dq_list:
+            l_weights.append(l.weight)
+            scales.append(l.scale)
+        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
     cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

@@ -254,6 +308,9 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
         mode="decode",
         transpose_value=transpose_value_cache,
         dtype=np_dtype,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                         f"decoder_layer_{layer_idx}",
@@ -19,6 +19,7 @@ import torch
 import numpy as np
 import os
 from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead


 def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):

@@ -27,18 +28,16 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     rms_norm_eps = model.config.rms_norm_eps
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
-    if model.config.intermediate_size == 18944:
-        lm_heads = model.lm_head.lm_heads  # Qwen2-7B is always SlicedLMHead
-    else:
-        lm_heads = [model.lm_head]
-    if n_splits_linear == 1:
-        weights = [(lm_heads[0].weight, lm_heads[0].scale)]
+    lm_head = model.lm_head
+    if not isinstance(lm_head, SlicedLMHead):
+        weights = [(lm_head.weight, lm_head.scale)]
     else:
+        lm_heads = lm_head.lm_heads
         lm_head_weights = []
         scales = []
-        for i in range(n_splits_linear):
-            lm_head_weights.append(lm_heads[i].weight)
-            scales.append(lm_heads[i].scale)
+        for l in lm_heads:
+            lm_head_weights.append(l.weight)
+            scales.append(l.scale)
         weights = [(torch.stack(lm_head_weights, axis=0),
                     torch.stack(scales, axis=0))]
     if isinstance(weights[0], tuple):
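Keying the head conversion on `isinstance(model.lm_head, SlicedLMHead)` instead of on Qwen2-7B's intermediate size (18944) lets the same path serve any checkpoint whose head was sliced in `optimize_llm_pre`, including Qwen2-1.5B with group-wise quantization. A hedged sketch of the dispatch with stand-in classes; `SlicedLMHeadStub` below is a dummy, not the ipex-llm implementation:

    import torch
    from torch import nn

    class SlicedLMHeadStub(nn.Module):
        """Stand-in for ipex-llm's SlicedLMHead: keeps a list of per-slice heads."""
        def __init__(self, heads):
            super().__init__()
            self.lm_heads = nn.ModuleList(heads)

    def collect_lm_head_weights(lm_head):
        # Mirrors the new branch: plain weight for an ordinary head,
        # stacked per-slice weights for a sliced head.
        if not isinstance(lm_head, SlicedLMHeadStub):
            return [lm_head.weight]
        return [torch.stack([h.weight for h in lm_head.lm_heads], dim=0)]

    plain = nn.Linear(64, 1000, bias=False)
    sliced = SlicedLMHeadStub([nn.Linear(16, 1000, bias=False) for _ in range(4)])
    print(collect_lm_head_weights(plain)[0].shape)    # torch.Size([1000, 64])
    print(collect_lm_head_weights(sliced)[0].shape)   # torch.Size([4, 1000, 16])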
@@ -61,9 +60,9 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)

     # save weights bins files
-    if n_splits_linear == 1:
+    if not isinstance(lm_head, SlicedLMHead):
         weight_numpy = [
-            lm_heads[0].weight.data.numpy(), lm_heads[0].scale.data.numpy(),
+            lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
         ]
     else:
         weight_numpy = [v.numpy() for v in weights[0]]