[NPU] Support Baichuan groupwise & gw code refactor (#12337)
* support minicpm 1b & qwen 1.5b gw * support minicpm 1b * baichuan part * update * support minicpm 1b & qwen 1.5b gw * support minicpm 1b * baichuan part * update * update * update * baichuan support * code refactor * remove code * fix style * address comments * revert
This commit is contained in:
parent
812d5cc32e
commit
b2e69a896c
13 changed files with 367 additions and 434 deletions
|
|
@ -60,6 +60,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
|
||||
parser.add_argument("--max-context-len", type=int, default=1024)
|
||||
parser.add_argument("--max-prompt-len", type=int, default=512)
|
||||
parser.add_argument("--quantization_group_size", type=int, default=0)
|
||||
parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
|
||||
parser.add_argument("--disable-streaming", action="store_true", default=False)
|
||||
|
||||
|
|
@ -72,6 +73,7 @@ if __name__ == "__main__":
|
|||
pipeline=True,
|
||||
max_context_len=args.max_context_len,
|
||||
max_prompt_len=args.max_prompt_len,
|
||||
quantization_group_size=args.quantization_group_size,
|
||||
torch_dtype=torch.float16,
|
||||
attn_implementation="eager",
|
||||
transpose_value_cache=not args.disable_transpose_value_cache,
|
||||
|
|
|
|||
|
|
@ -50,6 +50,9 @@ from transformers.cache_utils import Cache
|
|||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from ipex_llm.transformers.npu_models.mp_models_base import run_model
|
||||
from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
|
||||
from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
|
||||
class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
||||
|
|
@ -75,12 +78,18 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
device: str = "NPU",
|
||||
rms_norm_eps,
|
||||
intermediate_size,
|
||||
n_splits_linear: int = 1,
|
||||
n_splits_down_proj: int = 1,
|
||||
group_size: int = 0
|
||||
):
|
||||
super().__init__(max_seq_len=max_seq_len,
|
||||
transpose_value=transpose_value,
|
||||
dtype=dtype,
|
||||
profile=profile,
|
||||
device=device)
|
||||
device=device,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size)
|
||||
self.max_seq_len = max_seq_len
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dtype = dtype
|
||||
|
|
@ -115,8 +124,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
|
||||
dtype=np.int64)
|
||||
else:
|
||||
attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
|
||||
dtype=np.int64)
|
||||
attention_mask = None
|
||||
|
||||
position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
|
||||
# self.num_key_value_heads = num_key_value_heads
|
||||
|
|
@ -178,6 +186,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
post_attention_layernorm_weight=post_attn_layernorm_weights[i],
|
||||
past_key=past_keys[i],
|
||||
past_value=past_values[i],
|
||||
use_prefill_sdp=True,
|
||||
)
|
||||
curr_key_values.append((new_key_states, new_value_states))
|
||||
|
||||
|
|
@ -189,6 +198,9 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
new_value_states = self.convert_to_fp16(curr_key_values[i][1])
|
||||
|
||||
print("start compiling")
|
||||
if mode == "prefill" and os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT", "0") != "1":
|
||||
self.compile(npu_dpu_groups=6)
|
||||
else:
|
||||
self.compile()
|
||||
|
||||
def attention(self,
|
||||
|
|
@ -206,15 +218,23 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
seq_len,
|
||||
q_bias=None,
|
||||
k_bias=None,
|
||||
v_bias=None):
|
||||
v_bias=None,
|
||||
use_prefill_sdp=False):
|
||||
hidden_size = num_heads * head_dim
|
||||
if self.n_splits_linear != 1:
|
||||
hidden_states = self.unsqueeze(hidden_states, axis=0)
|
||||
|
||||
proj = self.linear(
|
||||
hidden_states,
|
||||
3 * hidden_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
wt_dtype=self.dtype
|
||||
wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
|
||||
proj = self.reshape(proj, [-1, 3, hidden_size]) # b*s, 3, h
|
||||
proj = self.unsqueeze(proj, [0]) # b, s, 3, h
|
||||
proj = self.transpose(proj, [2, 1, 0, 3]) # 3, s, b, h
|
||||
|
|
@ -224,8 +244,14 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
key_states = self.reshape(proj[1, ...], [1, self.seq_len, num_heads, head_dim])
|
||||
key_states = self.transpose(key_states, [0, 2, 1, 3])
|
||||
value_states = self.reshape(proj[2, ...], [1, self.seq_len, num_heads, head_dim])
|
||||
|
||||
use_ov_sdp = (mode == "prefill") and use_prefill_sdp
|
||||
if self.transpose_value:
|
||||
value_states = self.transpose(value_states, [0, 2, 3, 1])
|
||||
new_value_states = self.transpose(value_states, [0, 2, 3, 1])
|
||||
if use_ov_sdp:
|
||||
value_states = self.transpose(value_states, [0, 2, 1, 3])
|
||||
else:
|
||||
value_states = new_value_states
|
||||
else:
|
||||
value_states = self.transpose(value_states, [0, 2, 1, 3])
|
||||
|
||||
|
|
@ -243,7 +269,6 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
head_dim=head_dim,
|
||||
)
|
||||
new_key_states = key_states
|
||||
new_value_states = value_states
|
||||
|
||||
if self.mode == "decode":
|
||||
key_states = self.concat(past_key, key_states, axis=-2)
|
||||
|
|
@ -252,6 +277,14 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
else:
|
||||
value_states = self.concat(past_value, value_states, axis=-2)
|
||||
|
||||
if use_ov_sdp:
|
||||
value_states = self.convert_to_fp32(value_states)
|
||||
key_states = self.convert_to_fp32(key_states)
|
||||
query_states = self.convert_to_fp32(query_states)
|
||||
attn_output = self.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states, None, True)
|
||||
attn_output = self.convert_to_fp16(attn_output)
|
||||
else:
|
||||
attn_weight = self.matmul(query_states, key_states, False, True) / (
|
||||
math.sqrt(self.head_dim))
|
||||
attention_mask = self.convert_to_fp16(attention_mask)
|
||||
|
|
@ -265,7 +298,10 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
|
||||
|
||||
attn_output = self.linear(
|
||||
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
|
||||
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
return attn_output, new_key_states, new_value_states
|
||||
|
||||
|
|
@ -278,6 +314,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
post_attention_layernorm_weight,
|
||||
past_key=None,
|
||||
past_value=None,
|
||||
use_prefill_sdp=False,
|
||||
):
|
||||
|
||||
residual = hidden_states
|
||||
|
|
@ -298,12 +335,13 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
|
|||
num_heads=self.num_heads,
|
||||
head_dim=self.head_dim,
|
||||
seq_len=self.seq_len,
|
||||
use_prefill_sdp=use_prefill_sdp,
|
||||
)
|
||||
|
||||
hidden_states = self.eltwise_add(residual, attn_output)
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
|
||||
hidden_states = self.eltwise_add(residual, hidden_states)
|
||||
hidden_states = self.convert_to_fp16(hidden_states)
|
||||
|
||||
|
|
@ -329,6 +367,9 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
max_seq_len: int = 1024,
|
||||
transpose_value: bool = False,
|
||||
do_print: bool = False,
|
||||
n_splits_linear: int = 1,
|
||||
n_splits_down_proj: int = 1,
|
||||
group_size: int = 0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -338,6 +379,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
for w in parameters:
|
||||
if isinstance(w, tuple): # from QuantizedLinear
|
||||
op_parameters.append((w[0].numpy(), w[1].numpy()))
|
||||
elif w.dtype in [torch.int8, torch.uint8]: # QuantizedLinear weight
|
||||
op_parameters.append(w.numpy())
|
||||
elif isinstance(w, np.ndarray): # scale
|
||||
op_parameters.append(w)
|
||||
else:
|
||||
op_parameters.append(w.to(torch.float16).numpy())
|
||||
self.op_parameters = op_parameters
|
||||
|
|
@ -346,6 +391,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
self.transpose_value = transpose_value
|
||||
if isinstance(parameters[0], tuple):
|
||||
np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
|
||||
elif parameters[0].dtype == torch.int8:
|
||||
np_dtype = np.int8
|
||||
elif parameters[0].dtype == torch.uint8:
|
||||
np_dtype = np.uint8
|
||||
else: # FP16 Linear
|
||||
np_dtype = np.float16
|
||||
|
||||
|
|
@ -380,6 +429,9 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
|
|||
mode="decode",
|
||||
transpose_value=self.transpose_value,
|
||||
dtype=np_dtype,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
)
|
||||
self.backend_decoders.append(decoder)
|
||||
|
||||
|
|
@ -453,6 +505,9 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
|
|||
intermediate_size,
|
||||
max_seq_len: int = 128,
|
||||
transpose_value: bool = False,
|
||||
n_splits_linear: int = 1,
|
||||
n_splits_down_proj: int = 1,
|
||||
group_size: int = 0
|
||||
):
|
||||
super().__init__()
|
||||
self.op_parameters = parameters
|
||||
|
|
@ -481,6 +536,9 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
|
|||
mode="prefill",
|
||||
transpose_value=self.transpose_value,
|
||||
dtype=np_dtype,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
)
|
||||
self.layer_norm_0 = layer_norm_0
|
||||
self.layer_norm_1 = layer_norm_1
|
||||
|
|
@ -507,7 +565,6 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
|
|||
|
||||
backend_cls = self.backend_cls_prefill
|
||||
inputs = (hidden_states.to(torch.float16),
|
||||
attention_mask.to(torch.int64),
|
||||
position_ids.to(torch.int64))
|
||||
inputs += (self.layer_norm_0, self.layer_norm_1)
|
||||
hidden_states, past_key, past_value = run_model(
|
||||
|
|
@ -557,22 +614,28 @@ def run_decode(
|
|||
head_dim = model.model.layers[layer_start].self_attn.head_dim
|
||||
rms_norm_eps = model.config.rms_norm_eps
|
||||
intermediate_size = model.config.intermediate_size
|
||||
group_size = getattr(model.config, "group_size", 0)
|
||||
layer_weights = []
|
||||
input_layer_norm_weights = []
|
||||
post_attn_layernorm_weights = []
|
||||
layer_indexs = range(layer_start, layer_end)
|
||||
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
|
||||
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
|
||||
for layer_idx in layer_indexs:
|
||||
curr_layer = model.model.layers[layer_idx]
|
||||
attn_layer = curr_layer.self_attn
|
||||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = [
|
||||
(attn_layer.W_pack.weight, attn_layer.W_pack.scale),
|
||||
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
|
||||
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
|
||||
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
|
||||
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
|
||||
]
|
||||
weights = []
|
||||
for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
|
|
@ -599,6 +662,9 @@ def run_decode(
|
|||
max_seq_len=max_seq_len,
|
||||
transpose_value=transpose_value_cache,
|
||||
do_print=False,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
)
|
||||
|
||||
dist.barrier()
|
||||
|
|
@ -754,23 +820,29 @@ def run_prefill(
|
|||
head_dim = model.model.layers[layer_start].self_attn.head_dim
|
||||
rms_norm_eps = model.config.rms_norm_eps
|
||||
intermediate_size = model.config.intermediate_size
|
||||
group_size = getattr(model.config, "group_size", 0)
|
||||
deocderlayers = []
|
||||
layer_weights = []
|
||||
input_layer_norm_weights = []
|
||||
post_attn_layernorm_weights = []
|
||||
layer_indexs = range(layer_start, layer_end)
|
||||
n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
|
||||
n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
|
||||
for layer_idx in layer_indexs:
|
||||
curr_layer = model.model.layers[layer_idx]
|
||||
attn_layer = curr_layer.self_attn
|
||||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = [
|
||||
(attn_layer.W_pack.weight, attn_layer.W_pack.scale),
|
||||
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
|
||||
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
|
||||
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
|
||||
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
|
||||
]
|
||||
weights = []
|
||||
for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
|
|
@ -791,6 +863,9 @@ def run_prefill(
|
|||
intermediate_size=intermediate_size,
|
||||
max_seq_len=max_output_len,
|
||||
transpose_value=transpose_value_cache,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
)
|
||||
|
||||
layer_weights.extend(weights)
|
||||
|
|
@ -1025,3 +1100,71 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
|
|||
)
|
||||
|
||||
return baichuan_fused_model_forward
|
||||
|
||||
|
||||
def baichuan2_causal_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
# ipex-llm change start
|
||||
hidden_states = reshape_lm_head_input(hidden_states)
|
||||
# ipex-llm change end
|
||||
logits = self.lm_head(hidden_states)
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
softmax_normalizer = shift_logits.max(-1).values ** 2
|
||||
z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels) + z_loss
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -75,10 +75,11 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
|
|||
from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
|
||||
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
|
||||
attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
baichuan_attn_module_names = ["W_pack", "o_proj"]
|
||||
mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
|
||||
if (
|
||||
isinstance(module, (Qwen2Attention, LlamaAttention))
|
||||
or module.__class__.__name__ in ['MiniCPMAttention', 'Attention']
|
||||
or module.__class__.__name__ in ['MiniCPMAttention']
|
||||
):
|
||||
for name in attn_module_names:
|
||||
setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
|
||||
|
|
@ -97,3 +98,10 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
|
|||
n_splits=n_splits_mlp,
|
||||
load=load))
|
||||
delattr(module, name)
|
||||
elif module.__class__.__name__ == 'Attention' and module.config.model_type == 'baichuan':
|
||||
# baichuan attention
|
||||
for name in baichuan_attn_module_names:
|
||||
setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
|
||||
n_splits=n_splits_hidden_size,
|
||||
load=load))
|
||||
delattr(module, name)
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
|
|||
model.llm.config.model_type = "llama"
|
||||
model = model.llm
|
||||
|
||||
if model.config.model_type in ["qwen2", "llama", "minicpm"]:
|
||||
if model.config.model_type in ["qwen2", "llama", "minicpm", "baichuan"]:
|
||||
from ipex_llm.transformers.npu_models.common import split_linears
|
||||
if quantization_group_size == 0:
|
||||
n_splits_linear = 1
|
||||
|
|
@ -245,6 +245,8 @@ def convert_baichuan(
|
|||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
convert_forward(model, module.BaichuanModel, baichuan_model_forward)
|
||||
from ipex_llm.transformers.npu_models.baichuan_mp import baichuan2_causal_forward
|
||||
convert_forward(model, module.BaichuanForCausalLM, baichuan2_causal_forward)
|
||||
|
||||
|
||||
def convert_minicpm(
|
||||
|
|
@ -392,7 +394,7 @@ def optimize_llm(
|
|||
if intra_pp is None:
|
||||
intra_pp = 2
|
||||
if inter_pp is None:
|
||||
inter_pp = 2
|
||||
inter_pp = 2 if group_size == 0 else 4
|
||||
convert_baichuan(model,
|
||||
max_output_len=max_context_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
|
|
|
|||
|
|
@ -560,23 +560,10 @@ def run_decode(
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
|
|
@ -584,17 +571,6 @@ def run_decode(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
|
|
@ -844,40 +820,15 @@ def run_prefill(
|
|||
|
||||
weights = []
|
||||
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0),
|
||||
torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
|
||||
|
|
|
|||
|
|
@ -540,23 +540,10 @@ def run_decode(
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
|
|
@ -564,17 +551,6 @@ def run_decode(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
|
||||
|
|
@ -783,24 +759,10 @@ def run_prefill(
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
|
|
@ -808,17 +770,6 @@ def run_prefill(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
|
||||
|
|
|
|||
|
|
@ -138,13 +138,18 @@ class LLMBaseNNFactory(NNFactory):
|
|||
use_prefill_sdp=False):
|
||||
hidden_size = num_heads * head_dim
|
||||
num_key_value_groups = num_heads // num_key_value_heads
|
||||
if self.n_splits_linear == 1:
|
||||
if self.n_splits_linear != 1:
|
||||
hidden_states = self.unsqueeze(hidden_states, axis=0)
|
||||
|
||||
query_states = self.linear(
|
||||
hidden_states,
|
||||
num_heads * head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
|
||||
key_states = self.linear(
|
||||
|
|
@ -153,6 +158,9 @@ class LLMBaseNNFactory(NNFactory):
|
|||
hidden_size,
|
||||
bias=False,
|
||||
wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
|
||||
value_states = self.linear(
|
||||
|
|
@ -161,24 +169,10 @@ class LLMBaseNNFactory(NNFactory):
|
|||
hidden_size,
|
||||
bias=False,
|
||||
wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
else:
|
||||
hidden_states = self.unsqueeze(hidden_states, axis=0)
|
||||
query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
|
||||
hidden_size, self.n_splits_linear,
|
||||
wt_dtype=self.dtype,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
|
||||
hidden_size, self.n_splits_linear,
|
||||
wt_dtype=self.dtype,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
|
||||
hidden_size, self.n_splits_linear,
|
||||
wt_dtype=self.dtype,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
|
||||
if q_bias is not None:
|
||||
query_states = query_states + q_bias
|
||||
|
|
@ -263,15 +257,12 @@ class LLMBaseNNFactory(NNFactory):
|
|||
attn_output = self.transpose(attn_output, [0, 2, 1, 3])
|
||||
attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])
|
||||
|
||||
if self.n_splits_linear == 1:
|
||||
attn_output = self.linear(
|
||||
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
|
||||
self.n_splits_linear, wt_dtype=self.dtype,
|
||||
attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
return attn_output, new_key_states, new_value_states
|
||||
|
||||
def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias):
|
||||
|
|
@ -434,38 +425,26 @@ class LLMBaseNNFactory(NNFactory):
|
|||
return w_2
|
||||
|
||||
def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
|
||||
if self.n_splits_linear == 1:
|
||||
mm1 = self.linear(
|
||||
hidden_states, self.intermediate_size, self.hidden_size, bias=False,
|
||||
wt_dtype=self.dtype
|
||||
wt_dtype=self.dtype, n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
mm2 = self.linear(
|
||||
hidden_states, self.intermediate_size, self.hidden_size, bias=False,
|
||||
wt_dtype=self.dtype
|
||||
wt_dtype=self.dtype, n_splits=self.n_splits_linear,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill")
|
||||
) # type: ignore[attr-defined]
|
||||
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
|
||||
else:
|
||||
invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
|
||||
mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
|
||||
self.n_splits_linear, wt_dtype=self.dtype,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
|
||||
self.n_splits_linear, wt_dtype=self.dtype,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
mm1 = self.eltwise_mul(self.swish(mm1), mm2) # type: ignore[attr-defined]
|
||||
|
||||
if self.n_splits_down_proj == 1:
|
||||
hidden_states = self.linear(
|
||||
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
|
||||
hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
|
||||
self.n_splits_down_proj, wt_dtype=self.dtype,
|
||||
mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype,
|
||||
n_splits=self.n_splits_down_proj,
|
||||
scale_factor=(self.group_size == 0),
|
||||
is_prefill=(mode == "prefill"))
|
||||
is_prefill=(mode == "prefill")
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def layer_norm(self, hidden_states, layernorm_weight):
|
||||
|
|
@ -571,8 +550,26 @@ class LLMBaseNNFactory(NNFactory):
|
|||
self.input_ops.append(op)
|
||||
return op
|
||||
|
||||
def linear(self, *args, **kwargs):
|
||||
op = super().linear(*args, **kwargs)
|
||||
def linear(self,
|
||||
input_node: ctypes._Pointer,
|
||||
output_channels: int,
|
||||
input_channels: int,
|
||||
bias: Optional[bool] = False,
|
||||
act_dtype: npt.DTypeLike = np.float16,
|
||||
wt_dtype: npt.DTypeLike = np.float16,
|
||||
n_splits: int = 1,
|
||||
scale_factor: bool = True,
|
||||
is_prefill: bool = False):
|
||||
if n_splits == 1:
|
||||
op = super().linear(input_node, output_channels,
|
||||
input_channels, bias, act_dtype,
|
||||
wt_dtype, scale_factor=scale_factor)
|
||||
else:
|
||||
op = super().dq_split_linear(input_node, n_splits,
|
||||
output_channels, input_channels,
|
||||
bias=bias, act_dtype=act_dtype,
|
||||
wt_dtype=wt_dtype, scale_factor=scale_factor,
|
||||
is_prefill=is_prefill)
|
||||
self.linear_ops.append(op)
|
||||
return op
|
||||
|
||||
|
|
|
|||
|
|
@ -586,23 +586,10 @@ def run_decode(
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
|
|
@ -610,17 +597,6 @@ def run_decode(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
|
||||
|
|
@ -839,23 +815,10 @@ def run_prefill(
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
|
|
@ -863,17 +826,6 @@ def run_prefill(
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
|
|||
vocab_size = model.config.vocab_size
|
||||
model_norm = model.model.norm
|
||||
lm_head = model.lm_head
|
||||
if n_splits_linear == 1:
|
||||
weights = [(lm_head.weight, lm_head.scale)]
|
||||
else:
|
||||
lm_heads = lm_head.lm_heads
|
||||
lm_head_weights = []
|
||||
scales = []
|
||||
for l in lm_heads:
|
||||
lm_head_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights = [(torch.stack(lm_head_weights, axis=0),
|
||||
torch.stack(scales, axis=0))]
|
||||
if isinstance(weights[0], tuple):
|
||||
np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
|
||||
else: # FP16 Linear
|
||||
|
|
@ -44,13 +54,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
|
|||
dtype=np_dtype,
|
||||
model_norm_weight=model_norm.weight.to(torch.float16),
|
||||
vocab_size=vocab_size,
|
||||
n_splits=n_splits_linear
|
||||
)
|
||||
last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)
|
||||
|
||||
# save weights bins files
|
||||
if n_splits_linear == 1:
|
||||
weight_numpy = [
|
||||
lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
|
||||
]
|
||||
else:
|
||||
weight_numpy = [v.numpy() for v in weights[0]]
|
||||
|
||||
for idx, weight in enumerate(weight_numpy):
|
||||
bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
|
||||
|
|
@ -83,17 +97,15 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
weights = [
|
||||
(attn_layer.W_pack.weight, attn_layer.W_pack.scale),
|
||||
(attn_layer.o_proj.weight, attn_layer.o_proj.scale),
|
||||
(mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
|
||||
(mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
|
||||
(mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
|
||||
]
|
||||
else:
|
||||
# TODO
|
||||
pass
|
||||
for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
|
||||
|
|
@ -119,6 +131,9 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
|
|||
mode="decode",
|
||||
transpose_value=transpose_value_cache,
|
||||
dtype=np_dtype,
|
||||
n_splits_linear=n_splits_linear,
|
||||
n_splits_down_proj=n_splits_down_proj,
|
||||
group_size=group_size
|
||||
)
|
||||
rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
|
||||
f"decoder_layer_{layer_idx}",
|
||||
|
|
|
|||
|
|
@ -91,21 +91,21 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
|
|||
self.head_dim = self.hidden_size // self.num_heads
|
||||
|
||||
# define input, the order self.parameter matters
|
||||
if n_splits == 1:
|
||||
input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
|
||||
else:
|
||||
input = self.create_input_op((1, self.batch_size, self.hidden_size))
|
||||
|
||||
hidden_states = input
|
||||
|
||||
# model norm and lm head
|
||||
model_norm_weight = self.constant(model_norm_weight)
|
||||
hidden_states = self.layer_norm(hidden_states, model_norm_weight)
|
||||
if n_splits == 1:
|
||||
|
||||
hidden_states = self.linear(
|
||||
hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states = self.dq_split_linear(
|
||||
hidden_states, self.vocab_size, self.hidden_size, n_splits,
|
||||
wt_dtype=dtype, scale_factor=False
|
||||
hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype,
|
||||
n_splits=n_splits,
|
||||
scale_factor=(n_splits == 1),
|
||||
)
|
||||
|
||||
# define outputs
|
||||
|
|
|
|||
|
|
@ -174,40 +174,15 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0),
|
||||
torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
|
||||
|
|
|
|||
|
|
@ -109,36 +109,22 @@ class MiniCPMLMHead(LLMBaseNNFactory):
|
|||
hidden_states = self.layer_norm(hidden_states, model_norm_weight)
|
||||
if vocab_size == 122753:
|
||||
# for MiniCPM-2B-sft-bf16
|
||||
if n_splits == 1:
|
||||
hidden_states_1 = self.linear(
|
||||
hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
|
||||
hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
|
||||
n_splits=n_splits, scale_factor=(n_splits == 1)
|
||||
)
|
||||
hidden_states_2 = self.linear(
|
||||
hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states_1 = self.dq_split_linear(
|
||||
hidden_states, 73440, self.hidden_size,
|
||||
n_splits=n_splits, wt_dtype=dtype, scale_factor=False
|
||||
)
|
||||
hidden_states_2 = self.dq_split_linear(
|
||||
hidden_states, 73440, self.hidden_size,
|
||||
n_splits=n_splits, wt_dtype=dtype, scale_factor=False
|
||||
hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
|
||||
n_splits=n_splits, scale_factor=(n_splits == 1)
|
||||
)
|
||||
|
||||
hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
|
||||
hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
|
||||
else:
|
||||
# for MiniCPM-1B-sft-bf16
|
||||
if n_splits == 1:
|
||||
hidden_states = self.linear(
|
||||
hidden_states, self.vocab_size, self.hidden_size, bias=False,
|
||||
wt_dtype=self.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states = self.dq_split_linear(
|
||||
hidden_states, self.vocab_size, self.hidden_size,
|
||||
n_splits=n_splits, wt_dtype=dtype, scale_factor=False
|
||||
wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1)
|
||||
)
|
||||
|
||||
# define outputs
|
||||
|
|
@ -245,40 +231,15 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0),
|
||||
torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
|
||||
|
|
|
|||
|
|
@ -99,23 +99,10 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
|
|||
mlp_layer = curr_layer.mlp
|
||||
|
||||
weights = []
|
||||
if n_splits_linear == 1:
|
||||
for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
|
||||
attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list,
|
||||
attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list,
|
||||
mlp_layer.up_proj_dq_list):
|
||||
weights.append((q.weight, q.scale))
|
||||
weights.append((k.weight, k.scale))
|
||||
weights.append((v.weight, v.scale))
|
||||
weights.append((o.weight, o.scale))
|
||||
weights.append((g.weight, g.scale))
|
||||
weights.append((u.weight, u.scale))
|
||||
else:
|
||||
for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
|
||||
attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
|
||||
mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
|
||||
mlp_layer.down_proj_dq_list]:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in layer_list:
|
||||
|
|
@ -123,17 +110,6 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
|
|||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
if n_splits_down_proj == 1:
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
weights.append((l.weight, l.scale))
|
||||
else:
|
||||
l_weights = []
|
||||
scales = []
|
||||
for l in mlp_layer.down_proj_dq_list:
|
||||
l_weights.append(l.weight)
|
||||
scales.append(l.scale)
|
||||
weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
|
||||
|
||||
q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)
|
||||
k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)
|
||||
v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)
|
||||
|
|
|
|||
Loading…
Reference in a new issue