[NPU] Compatible with other third-party models like auto-round (#12620)
* support third-party model
* simplify code
* fix style
* fix sym int4 GW
* code refactor
* fix
Parent: a9abde0b5d
Commit: bbdbbb0d88

5 changed files with 64 additions and 136 deletions
@@ -162,6 +162,7 @@ class QuantizedLinear(torch.nn.Module):
         self.zero = None
         if group_size != 0:
             self.scale = Parameter(scale, requires_grad=False)
+            if zero is not None:
                 self.zero = Parameter(zero, requires_grad=False)
         else:
             if self.weight.dtype == torch.uint8:
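For background on why the optional zero tensor matters here: asymmetric low-bit formats, which third-party quantizers such as auto-round can emit, store a per-group zero point next to the scale, while symmetric int4 needs only the scale. The sketch below is a minimal dequantization under one common convention; the function name, shapes, and group size are illustrative and do not claim to be the kernel QuantizedLinear actually runs on the NPU.

import torch

def dequantize_group_sketch(qweight, scale, zero=None, group_size=128):
    # qweight: (out_features, in_features) integer codes stored as uint8
    # scale:   (out_features, in_features // group_size) per-group scales
    # zero:    same shape as scale, or None for a symmetric scheme
    out_features, in_features = qweight.shape
    q = qweight.to(torch.float32).view(out_features, -1, group_size)
    s = scale.to(torch.float32).unsqueeze(-1)   # broadcast over each group
    z = zero.to(torch.float32).unsqueeze(-1) if zero is not None else 0.0
    # Asymmetric: w = (q - zero) * scale; with zero=None this collapses to w = q * scale.
    return ((q - z) * s).view(out_features, in_features)

# Illustrative shapes only.
qw = torch.randint(0, 16, (64, 256), dtype=torch.uint8)
sc = torch.rand(64, 256 // 128)
zp = torch.full((64, 256 // 128), 8.0)
w_asym = dequantize_group_sketch(qw, sc, zp)   # zero point present (asymmetric)
w_sym = dequantize_group_sketch(qw, sc)        # symmetric path, scale only

This is why the constructor above only wraps zero in a Parameter when the quantizer actually provides one.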
@@ -21,6 +21,7 @@ from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from typing import Sequence
 from intel_npu_acceleration_library.backend.factory import NNFactory
 import numpy as np
+import torch


 def update_names_of_IR_and_export_blob(model, model_name, dir, compile_blob=True, keep_ir=True,
@@ -170,3 +171,48 @@ class LLMEmbedding(NNFactory):
         print("start compiling")
         self.compile()
+
+
+def obtain_weight_from_single_layer(attn_layer, mlp_layer):
+    weights = []
+    if hasattr(attn_layer, "q_proj_dq_list"):
+        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
+            l_weights = []
+            scales = []
+            zeros = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+                if l.zero is not None:
+                    zeros.append(l.zero)
+            if len(zeros):
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
+                                torch.stack(zeros, axis=0)))
+            else:
+                weights.append((torch.stack(l_weights, axis=0),
+                                torch.stack(scales, axis=0)))
+    else:
+        for layer in [attn_layer.q_proj, attn_layer.k_proj,
+                      attn_layer.v_proj, attn_layer.o_proj,
+                      mlp_layer.gate_proj, mlp_layer.up_proj,
+                      mlp_layer.down_proj]:
+            if layer.zero is not None:
+                weights.append((layer.weight, layer.scale, layer.zero))
+            else:
+                weights.append((layer.weight, layer.scale))
+    return weights
+
+
+def obtain_qkv_bias_from_single_layer(attn_layer):
+    if hasattr(attn_layer, "q_proj_dq_list"):
+        q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)
+        k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)
+        v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)
+    else:
+        q_bias = attn_layer.q_proj.bias.to(torch.float16)
+        k_bias = attn_layer.k_proj.bias.to(torch.float16)
+        v_bias = attn_layer.v_proj.bias.to(torch.float16)
+    return q_bias, k_bias, v_bias
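To make the contract of the new helpers concrete, here is a small self-contained sketch that exercises the non-dq_list branch added for third-party (e.g. auto-round) checkpoints. The fake_proj stand-ins are purely illustrative; real callers pass the model's own attention and MLP modules, and obtain_weight_from_single_layer is assumed to be in scope.

import torch
from types import SimpleNamespace

def fake_proj(out_f=8, in_f=16, with_zero=True):
    # Mimics the attributes the helper reads from each projection module.
    return SimpleNamespace(weight=torch.randint(0, 16, (out_f, in_f), dtype=torch.uint8),
                           scale=torch.rand(out_f, 1),
                           zero=torch.rand(out_f, 1) if with_zero else None)

attn_layer = SimpleNamespace(q_proj=fake_proj(), k_proj=fake_proj(),
                             v_proj=fake_proj(), o_proj=fake_proj())
mlp_layer = SimpleNamespace(gate_proj=fake_proj(), up_proj=fake_proj(),
                            down_proj=fake_proj())

weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
print(len(weights))     # 7 projections (q, k, v, o, gate, up, down)
print(len(weights[0]))  # 3 -> (weight, scale, zero) because a zero point is present

obtain_qkv_bias_from_single_layer follows the same fallback pattern, reading attn_layer.q_proj.bias (and k/v) when the split q_proj_dq_list projections are absent, which the Qwen converter below relies on.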
@@ -18,7 +18,8 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
+    obtain_weight_from_single_layer
 from intel_npu_acceleration_library.backend.factory import NNFactory
@@ -261,26 +262,7 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     curr_layer = model.model.layers[layer_idx]
     attn_layer = curr_layer.self_attn
     mlp_layer = curr_layer.mlp
-    weights = []
-    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                       mlp_layer.down_proj_dq_list]:
-        l_weights = []
-        scales = []
-        zeros = []
-        for l in layer_list:
-            l_weights.append(l.weight)
-            scales.append(l.scale)
-            if l.zero is not None:
-                zeros.append(l.zero)
-        if len(zeros):
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
-                            torch.stack(zeros, axis=0)))
-        else:
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
+    weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
     if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
         # llama-2-7B & llama-3-8B
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
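For readers unfamiliar with the *_dq_list layout that this removed loop used to handle inline: each projection is split into several slices for the NPU, and the shared helper stacks the per-slice weights, scales, and zeros along a new leading axis. A toy shape check with made-up sizes:

import torch

n_splits, out_per_split, in_features = 2, 2048, 4096
slices = [torch.zeros(out_per_split, in_features, dtype=torch.uint8) for _ in range(n_splits)]
stacked = torch.stack(slices, axis=0)  # torch accepts axis as an alias for dim
print(stacked.shape)                   # torch.Size([2, 2048, 4096])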
@@ -400,32 +382,11 @@ def convert_fused_llama_layer(model, fused_layers, n_splits_linear, n_splits_dow
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
-    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
-    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = []
-        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                           mlp_layer.down_proj_dq_list]:
-            l_weights = []
-            scales = []
-            zeros = []
-            for l in layer_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-                if l.zero is not None:
-                    zeros.append(l.zero)
-            if len(zeros):
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
-                                torch.stack(zeros, axis=0)))
-            else:
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
+        weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
             # llama-2-7B & llama-3-8B
             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
@@ -18,7 +18,7 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob
+from .common import update_names_of_IR_and_export_blob, obtain_weight_from_single_layer
 from intel_npu_acceleration_library.backend.factory import NNFactory
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
 from typing import Sequence
@@ -309,26 +309,7 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     curr_layer = model.model.layers[layer_idx]
     attn_layer = curr_layer.self_attn
     mlp_layer = curr_layer.mlp
-    weights = []
-    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                       mlp_layer.down_proj_dq_list]:
-        l_weights = []
-        scales = []
-        zeros = []
-        for l in layer_list:
-            l_weights.append(l.weight)
-            scales.append(l.scale)
-            if l.zero is not None:
-                zeros.append(l.zero)
-        if len(zeros):
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
-                            torch.stack(zeros, axis=0)))
-        else:
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
+    weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
     cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
     layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -425,32 +406,11 @@ def convert_fused_minicpm_layer(model, fused_layers, n_splits_linear, n_splits_d
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
-    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
-    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = []
-        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                           mlp_layer.down_proj_dq_list]:
-            l_weights = []
-            scales = []
-            zeros = []
-            for l in layer_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-                if l.zero is not None:
-                    zeros.append(l.zero)
-            if len(zeros):
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
-                                torch.stack(zeros, axis=0)))
-            else:
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
+        weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -18,7 +18,8 @@
 import torch
 import numpy as np
 import os
-from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead
+from .common import update_names_of_IR_and_export_blob, LLMEmbedding, LowBitLLMLMHead, \
+    obtain_weight_from_single_layer, obtain_qkv_bias_from_single_layer
 from ipex_llm.transformers.npu_models.lm_head import SlicedLMHead
@@ -132,29 +133,8 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     curr_layer = model.model.layers[layer_idx]
     attn_layer = curr_layer.self_attn
     mlp_layer = curr_layer.mlp
-    weights = []
-    for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                       attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                       mlp_layer.down_proj_dq_list]:
-        l_weights = []
-        scales = []
-        zeros = []
-        for l in layer_list:
-            l_weights.append(l.weight)
-            scales.append(l.scale)
-            if l.zero is not None:
-                zeros.append(l.zero)
-        if len(zeros):
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
-                            torch.stack(zeros, axis=0)))
-        else:
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
-    q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)
-    k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)
-    v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)
+    weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
+    q_bias, k_bias, v_bias = obtain_qkv_bias_from_single_layer(attn_layer)
     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
     cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
     layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -263,32 +243,11 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
     k_biases = []
     v_biases = []
     layer_indexs = range(layer_start, layer_end)
-    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
-    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = []
-        for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
-                           mlp_layer.down_proj_dq_list]:
-            l_weights = []
-            scales = []
-            zeros = []
-            for l in layer_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-                if l.zero is not None:
-                    zeros.append(l.zero)
-            if len(zeros):
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0),
-                                torch.stack(zeros, axis=0)))
-            else:
-                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
+        weights = obtain_weight_from_single_layer(attn_layer, mlp_layer)
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -297,9 +256,10 @@ def convert_fused_qwen_layer(model, fused_layers, n_splits_linear, n_splits_down
         layer_weights.extend(weights)
         input_layer_norm_weights.append(layer_norm_0)
         post_attn_layernorm_weights.append(layer_norm_1)
-        q_biases.append(attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16))
-        k_biases.append(attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16))
-        v_biases.append(attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16))
+        q_bias, k_bias, v_bias = obtain_qkv_bias_from_single_layer(attn_layer)
+        q_biases.append(q_bias)
+        k_biases.append(k_bias)
+        v_biases.append(v_bias)

         # save weight
         input_lm_bin_file = os.path.join(weight_dir, f"model_{layer_idx}_input_3.bin")