[NPU] Support Baichuan groupwise & gw code refactor (#12337)

* support minicpm 1b & qwen 1.5b gw

* support minicpm 1b

* baichuan part

* update

* support minicpm 1b & qwen 1.5b gw

* support minicpm 1b

* baichuan part

* update

* update

* update

* baichuan support

* code refactor

* remove code

* fix style

* address comments

* revert
Yina Chen 2024-11-08 05:42:42 +02:00 committed by GitHub
parent 812d5cc32e
commit b2e69a896c
13 changed files with 367 additions and 434 deletions


@@ -60,6 +60,7 @@ if __name__ == "__main__":
     parser.add_argument("--n-predict", type=int, default=32, help="Max tokens to predict")
     parser.add_argument("--max-context-len", type=int, default=1024)
     parser.add_argument("--max-prompt-len", type=int, default=512)
+    parser.add_argument("--quantization_group_size", type=int, default=0)
     parser.add_argument("--disable-transpose-value-cache", action="store_true", default=False)
     parser.add_argument("--disable-streaming", action="store_true", default=False)
@@ -72,6 +73,7 @@ if __name__ == "__main__":
         pipeline=True,
         max_context_len=args.max_context_len,
         max_prompt_len=args.max_prompt_len,
+        quantization_group_size=args.quantization_group_size,
         torch_dtype=torch.float16,
         attn_implementation="eager",
         transpose_value_cache=not args.disable_transpose_value_cache,
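For context, a minimal usage sketch of how the new flag flows into the loader. The model id and the group size value below are placeholders, not part of this diff; 0 keeps the existing channel-wise quantization, while a positive value such as 64 selects the groupwise path.

import torch
from ipex_llm.transformers.npu_model import AutoModelForCausalLM

# Hypothetical invocation; quantization_group_size is the only new knob added here.
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",    # placeholder model path
    load_in_low_bit="sym_int4",
    optimize_model=True,
    pipeline=True,
    max_context_len=1024,
    max_prompt_len=512,
    quantization_group_size=64,          # 0 = per-channel, >0 = groupwise
    torch_dtype=torch.float16,
    attn_implementation="eager",
    trust_remote_code=True,
)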


@@ -50,6 +50,9 @@ from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.npu_models.mp_models_base import run_model
 from ipex_llm.transformers.npu_models.mp_models_base import LLMBaseNNFactory
+from ipex_llm.transformers.npu_models.common import reshape_lm_head_input
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from torch.nn import CrossEntropyLoss


 class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
@@ -75,12 +78,18 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         device: str = "NPU",
         rms_norm_eps,
         intermediate_size,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__(max_seq_len=max_seq_len,
                          transpose_value=transpose_value,
                          dtype=dtype,
                          profile=profile,
-                         device=device)
+                         device=device,
+                         n_splits_linear=n_splits_linear,
+                         n_splits_down_proj=n_splits_down_proj,
+                         group_size=group_size)
         self.max_seq_len = max_seq_len
         self.intermediate_size = intermediate_size
         self.dtype = dtype
@@ -115,8 +124,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
                                                   dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+            attention_mask = None

         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
         # self.num_key_value_heads = num_key_value_heads
@@ -178,6 +186,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
+                use_prefill_sdp=True,
             )
             curr_key_values.append((new_key_states, new_value_states))
@@ -189,6 +198,9 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

         print("start compiling")
+        if mode == "prefill" and os.environ.get("IPEX_LLM_NPU_DISABLE_COMPILE_OPT", "0") != "1":
+            self.compile(npu_dpu_groups=6)
+        else:
             self.compile()

     def attention(self,
@@ -206,15 +218,23 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                   seq_len,
                   q_bias=None,
                   k_bias=None,
-                  v_bias=None):
+                  v_bias=None,
+                  use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
+        if self.n_splits_linear != 1:
+            hidden_states = self.unsqueeze(hidden_states, axis=0)
         proj = self.linear(
             hidden_states,
             3 * hidden_size,
             hidden_size,
             bias=False,
-            wt_dtype=self.dtype
+            wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )
         proj = self.reshape(proj, [-1, 3, hidden_size])  # b*s, 3, h
         proj = self.unsqueeze(proj, [0])  # b, s, 3, h
         proj = self.transpose(proj, [2, 1, 0, 3])  # 3, s, b, h
@@ -224,8 +244,14 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         key_states = self.reshape(proj[1, ...], [1, self.seq_len, num_heads, head_dim])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
         value_states = self.reshape(proj[2, ...], [1, self.seq_len, num_heads, head_dim])
+        use_ov_sdp = (mode == "prefill") and use_prefill_sdp
         if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
+            new_value_states = self.transpose(value_states, [0, 2, 3, 1])
+            if use_ov_sdp:
+                value_states = self.transpose(value_states, [0, 2, 1, 3])
+            else:
+                value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
@@ -243,7 +269,6 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             head_dim=head_dim,
         )
         new_key_states = key_states
-        new_value_states = value_states

         if self.mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
@@ -252,6 +277,14 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         else:
             value_states = self.concat(past_value, value_states, axis=-2)

+        if use_ov_sdp:
+            value_states = self.convert_to_fp32(value_states)
+            key_states = self.convert_to_fp32(key_states)
+            query_states = self.convert_to_fp32(query_states)
+            attn_output = self.scaled_dot_product_attention(
+                query_states, key_states, value_states, None, True)
+            attn_output = self.convert_to_fp16(attn_output)
+        else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(self.head_dim))
             attention_mask = self.convert_to_fp16(attention_mask)
@@ -265,7 +298,10 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])

         attn_output = self.linear(
-            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
+            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )
         return attn_output, new_key_states, new_value_states
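A note on the recurring scale_factor=(self.group_size == 0) argument: with channel-wise quantization (group_size == 0) there is a single scale per output channel, which the backend can fold in as one multiplier, while groupwise quantization carries one scale per (channel, group). A rough reference of what groupwise dequantization computes, assuming sym_int8 storage; illustrative only, not the NPU kernel:

import torch

def dequant_groupwise(w_int8: torch.Tensor, scale: torch.Tensor, group_size: int) -> torch.Tensor:
    # w_int8: [out_features, in_features]; scale: [out_features, in_features // group_size]
    out_f, in_f = w_int8.shape
    w = w_int8.to(torch.float16).reshape(out_f, in_f // group_size, group_size)
    return (w * scale.unsqueeze(-1)).reshape(out_f, in_f)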
@@ -278,6 +314,7 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
                       post_attention_layernorm_weight,
                       past_key=None,
                       past_value=None,
+                      use_prefill_sdp=False,
                       ):
         residual = hidden_states
@@ -298,12 +335,13 @@ class LowBitBaichuanMultiDecoderlayer(LLMBaseNNFactory):
             num_heads=self.num_heads,
             head_dim=self.head_dim,
             seq_len=self.seq_len,
+            use_prefill_sdp=use_prefill_sdp,
         )
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
         hidden_states = self.layer_norm(hidden_states, post_attention_layernorm_weight)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, self.seq_len, self.mode)
         hidden_states = self.eltwise_add(residual, hidden_states)
         hidden_states = self.convert_to_fp16(hidden_states)
@@ -329,6 +367,9 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         max_seq_len: int = 1024,
         transpose_value: bool = False,
         do_print: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
@@ -338,6 +379,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         for w in parameters:
             if isinstance(w, tuple):  # from QuantizedLinear
                 op_parameters.append((w[0].numpy(), w[1].numpy()))
+            elif w.dtype in [torch.int8, torch.uint8]:  # QuantizedLinear weight
+                op_parameters.append(w.numpy())
+            elif isinstance(w, np.ndarray):  # scale
+                op_parameters.append(w)
             else:
                 op_parameters.append(w.to(torch.float16).numpy())
         self.op_parameters = op_parameters
@@ -346,6 +391,10 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
         self.transpose_value = transpose_value
         if isinstance(parameters[0], tuple):
             np_dtype = np.int8 if parameters[0][0].dtype == torch.int8 else np.uint8
+        elif parameters[0].dtype == torch.int8:
+            np_dtype = np.int8
+        elif parameters[0].dtype == torch.uint8:
+            np_dtype = np.uint8
         else:  # FP16 Linear
             np_dtype = np.float16
@@ -380,6 +429,9 @@ class FusedBaichuanLowBitMultiDecoderlayer(torch.nn.Module):
                 mode="decode",
                 transpose_value=self.transpose_value,
                 dtype=np_dtype,
+                n_splits_linear=n_splits_linear,
+                n_splits_down_proj=n_splits_down_proj,
+                group_size=group_size
             )
             self.backend_decoders.append(decoder)
@@ -453,6 +505,9 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         intermediate_size,
         max_seq_len: int = 128,
         transpose_value: bool = False,
+        n_splits_linear: int = 1,
+        n_splits_down_proj: int = 1,
+        group_size: int = 0
     ):
         super().__init__()
         self.op_parameters = parameters
@@ -481,6 +536,9 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
             mode="prefill",
             transpose_value=self.transpose_value,
             dtype=np_dtype,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
@@ -507,7 +565,6 @@ class FusedBaichuanLowBitDecoderlayer(torch.nn.Module):
         backend_cls = self.backend_cls_prefill
         inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
                   position_ids.to(torch.int64))
         inputs += (self.layer_norm_0, self.layer_norm_1)
         hidden_states, past_key, past_value = run_model(
@@ -557,22 +614,28 @@ def run_decode(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = [
-            (attn_layer.W_pack.weight, attn_layer.W_pack.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
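The weight packing for the NPU factory changes here: instead of one (weight, scale) tuple per projection, each projection now contributes the tensors of its *_dq_list entries stacked along a new leading split axis. A small shape sketch of the pattern (the sizes and the per-split layout are made up for illustration):

import torch

n_splits, out_features, in_per_split = 2, 4096, 2048   # assumed example sizes
split_weights = [torch.randint(-128, 127, (out_features, in_per_split), dtype=torch.int8)
                 for _ in range(n_splits)]
split_scales = [torch.rand(out_features, dtype=torch.float16) for _ in range(n_splits)]

stacked = (torch.stack(split_weights, axis=0),   # [n_splits, out_features, in_per_split]
           torch.stack(split_scales, axis=0))    # [n_splits, out_features]
print(stacked[0].shape, stacked[1].shape)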
@@ -599,6 +662,9 @@ def run_decode(
             max_seq_len=max_seq_len,
             transpose_value=transpose_value_cache,
             do_print=False,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
     dist.barrier()
@@ -754,23 +820,29 @@ def run_prefill(
     head_dim = model.model.layers[layer_start].self_attn.head_dim
     rms_norm_eps = model.config.rms_norm_eps
     intermediate_size = model.config.intermediate_size
+    group_size = getattr(model.config, "group_size", 0)
     deocderlayers = []
     layer_weights = []
     input_layer_norm_weights = []
     post_attn_layernorm_weights = []
     layer_indexs = range(layer_start, layer_end)
+    n_splits_linear = len(model.model.layers[0].mlp.gate_proj_dq_list)
+    n_splits_down_proj = len(model.model.layers[0].mlp.down_proj_dq_list)
     for layer_idx in layer_indexs:
         curr_layer = model.model.layers[layer_idx]
         attn_layer = curr_layer.self_attn
         mlp_layer = curr_layer.mlp
-        weights = [
-            (attn_layer.W_pack.weight, attn_layer.W_pack.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
+        weights = []
+        for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
+            l_weights = []
+            scales = []
+            for l in layer_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -791,6 +863,9 @@ def run_prefill(
             intermediate_size=intermediate_size,
             max_seq_len=max_output_len,
             transpose_value=transpose_value_cache,
+            n_splits_linear=n_splits_linear,
+            n_splits_down_proj=n_splits_down_proj,
+            group_size=group_size
         )
         layer_weights.extend(weights)
@@ -1025,3 +1100,71 @@ def gen_baichuan_fused_model_forward(prefill_runner, decode_runner):
         )
     return baichuan_fused_model_forward
+
+
+def baichuan2_causal_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+    # ipex-llm change start
+    hidden_states = reshape_lm_head_input(hidden_states)
+    # ipex-llm change end
+    logits = self.lm_head(hidden_states)
+    loss = None
+    if labels is not None:
+        # Shift so that tokens < n predict n
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        loss_fct = CrossEntropyLoss()
+        shift_logits = shift_logits.view(-1, self.config.vocab_size)
+        shift_labels = shift_labels.view(-1)
+        softmax_normalizer = shift_logits.max(-1).values ** 2
+        z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
+        # Enable model parallelism
+        shift_labels = shift_labels.to(shift_logits.device)
+        loss = loss_fct(shift_logits, shift_labels) + z_loss
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
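baichuan2_causal_forward mirrors the upstream Baichuan2 forward; the only ipex-llm change is the reshape_lm_head_input call before lm_head. A sketch of what that helper is expected to do, trimming long prefill hidden states to the last position so the (possibly split) lm_head only runs on one token; this is an assumption for illustration, not the exact library code:

import torch

def reshape_lm_head_input_sketch(x: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: keep only the last token's hidden state during prefill.
    if x.dim() == 3 and x.size(1) > 1:
        x = x[:, -1:, :]
    return x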


@@ -75,10 +75,11 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
     from transformers.models.qwen2.modeling_qwen2 import Qwen2MLP, Qwen2Attention
     from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
     attn_module_names = ["q_proj", "k_proj", "v_proj", "o_proj"]
+    baichuan_attn_module_names = ["W_pack", "o_proj"]
     mlp_module_names = ["down_proj", "up_proj", "gate_proj"]
     if (
         isinstance(module, (Qwen2Attention, LlamaAttention))
-        or module.__class__.__name__ in ['MiniCPMAttention', 'Attention']
+        or module.__class__.__name__ in ['MiniCPMAttention']
     ):
         for name in attn_module_names:
             setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
@@ -97,3 +98,10 @@ def split_linears(module: torch.nn.Module, n_splits_hidden_size=2, n_splits_down
                                                             n_splits=n_splits_mlp,
                                                             load=load))
             delattr(module, name)
+    elif module.__class__.__name__ == 'Attention' and module.config.model_type == 'baichuan':
+        # baichuan attention
+        for name in baichuan_attn_module_names:
+            setattr(module, f"{name}_dq_list", split_linear(getattr(module, name), name,
+                                                            n_splits=n_splits_hidden_size,
+                                                            load=load))
+            delattr(module, name)
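Baichuan keeps Q/K/V fused in a single W_pack linear, so it cannot reuse the q/k/v/o name list above; the new branch splits W_pack and o_proj into *_dq_list modules instead. A hedged sketch of the idea behind split_linear, breaking one large matmul into n smaller ones; the real helper also handles low-bit weights, scales and lazy loading, and the split axis here is an assumption for illustration:

import torch

def split_linear_sketch(linear: torch.nn.Linear, n_splits: int) -> torch.nn.ModuleList:
    chunks = torch.chunk(linear.weight, n_splits, dim=1)   # split along in_features (assumed)
    splits = []
    for c in chunks:
        piece = torch.nn.Linear(c.shape[1], c.shape[0], bias=False)
        piece.weight = torch.nn.Parameter(c.contiguous())
        splits.append(piece)
    return torch.nn.ModuleList(splits)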


@@ -87,7 +87,7 @@ def optimize_llm_pre(model: torch.nn.Module, qtype, mixed_precision,
             model.llm.config.model_type = "llama"
         model = model.llm

-    if model.config.model_type in ["qwen2", "llama", "minicpm"]:
+    if model.config.model_type in ["qwen2", "llama", "minicpm", "baichuan"]:
         from ipex_llm.transformers.npu_models.common import split_linears
         if quantization_group_size == 0:
             n_splits_linear = 1
@@ -245,6 +245,8 @@ def convert_baichuan(
     modeling_module_name = model.__class__.__module__
     module = importlib.import_module(modeling_module_name)
     convert_forward(model, module.BaichuanModel, baichuan_model_forward)
+    from ipex_llm.transformers.npu_models.baichuan_mp import baichuan2_causal_forward
+    convert_forward(model, module.BaichuanForCausalLM, baichuan2_causal_forward)


 def convert_minicpm(
@@ -392,7 +394,7 @@ def optimize_llm(
         if intra_pp is None:
             intra_pp = 2
         if inter_pp is None:
-            inter_pp = 2
+            inter_pp = 2 if group_size == 0 else 4
         convert_baichuan(model,
                         max_output_len=max_context_len,
                         max_prompt_len=max_prompt_len,


@@ -560,23 +560,10 @@ def run_decode(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -584,17 +571,6 @@ def run_decode(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):
             cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
             cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -844,40 +820,15 @@ def run_prefill(
         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
                 l_weights.append(l.weight)
                 scales.append(l.scale)
-                weights.append((torch.stack(l_weights, axis=0),
-                                torch.stack(scales, axis=0)))
-
-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):


@@ -540,23 +540,10 @@ def run_decode(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -564,17 +551,6 @@ def run_decode(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -783,24 +759,10 @@ def run_prefill(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -808,17 +770,6 @@ def run_prefill(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)


@@ -138,13 +138,18 @@ class LLMBaseNNFactory(NNFactory):
                   use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
-        if self.n_splits_linear == 1:
+        if self.n_splits_linear != 1:
+            hidden_states = self.unsqueeze(hidden_states, axis=0)
         query_states = self.linear(
             hidden_states,
             num_heads * head_dim,
             hidden_size,
             bias=False,
             wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )

         key_states = self.linear(
@@ -153,6 +158,9 @@ class LLMBaseNNFactory(NNFactory):
             hidden_size,
             bias=False,
             wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )

         value_states = self.linear(
@@ -161,24 +169,10 @@ class LLMBaseNNFactory(NNFactory):
             hidden_size,
             bias=False,
             wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
         )
-        else:
-            hidden_states = self.unsqueeze(hidden_states, axis=0)
-            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                hidden_size, self.n_splits_linear,
-                                                wt_dtype=self.dtype,
-                                                scale_factor=(self.group_size == 0),
-                                                is_prefill=(mode == "prefill"))
-            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                              hidden_size, self.n_splits_linear,
-                                              wt_dtype=self.dtype,
-                                              scale_factor=(self.group_size == 0),
-                                              is_prefill=(mode == "prefill"))
-            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                hidden_size, self.n_splits_linear,
-                                                wt_dtype=self.dtype,
-                                                scale_factor=(self.group_size == 0),
-                                                is_prefill=(mode == "prefill"))

         if q_bias is not None:
             query_states = query_states + q_bias
@@ -263,15 +257,12 @@ class LLMBaseNNFactory(NNFactory):
         attn_output = self.transpose(attn_output, [0, 2, 1, 3])
         attn_output = self.reshape(attn_output, [1, seq_len, hidden_size])

-        if self.n_splits_linear == 1:
-            attn_output = self.linear(
-                attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
-            )
-        else:
-            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                               self.n_splits_linear, wt_dtype=self.dtype,
-                                               scale_factor=(self.group_size == 0),
-                                               is_prefill=(mode == "prefill"))
+        attn_output = self.linear(
+            attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype,
+            n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
+        )

         return attn_output, new_key_states, new_value_states

     def paraformer_layer_norm(self, hidden_states, layernorm_weight, layernorm_bias):
@@ -434,38 +425,26 @@ class LLMBaseNNFactory(NNFactory):
         return w_2

     def mlp(self, hidden_states, seq_len=-1, mode="prefill"):
-        if self.n_splits_linear == 1:
-            mm1 = self.linear(
-                hidden_states, self.intermediate_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype
-            )
-            mm2 = self.linear(
-                hidden_states, self.intermediate_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype
-            )  # type: ignore[attr-defined]
-            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-        else:
-            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                       self.n_splits_linear, wt_dtype=self.dtype,
-                                       scale_factor=(self.group_size == 0),
-                                       is_prefill=(mode == "prefill"))
-            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                       self.n_splits_linear, wt_dtype=self.dtype,
-                                       scale_factor=(self.group_size == 0),
-                                       is_prefill=(mode == "prefill"))
-            mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
-
-        if self.n_splits_down_proj == 1:
-            hidden_states = self.linear(
-                mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype
-            )
-        else:
-            invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                 scale_factor=(self.group_size == 0),
-                                                 is_prefill=(mode == "prefill"))
+        mm1 = self.linear(
+            hidden_states, self.intermediate_size, self.hidden_size, bias=False,
+            wt_dtype=self.dtype, n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
+        )
+        mm2 = self.linear(
+            hidden_states, self.intermediate_size, self.hidden_size, bias=False,
+            wt_dtype=self.dtype, n_splits=self.n_splits_linear,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
+        )  # type: ignore[attr-defined]
+        mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
+        hidden_states = self.linear(
+            mm1, self.hidden_size, self.intermediate_size, bias=False, wt_dtype=self.dtype,
+            n_splits=self.n_splits_down_proj,
+            scale_factor=(self.group_size == 0),
+            is_prefill=(mode == "prefill")
+        )
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):
@@ -571,8 +550,26 @@ class LLMBaseNNFactory(NNFactory):
         self.input_ops.append(op)
         return op

-    def linear(self, *args, **kwargs):
-        op = super().linear(*args, **kwargs)
+    def linear(self,
+               input_node: ctypes._Pointer,
+               output_channels: int,
+               input_channels: int,
+               bias: Optional[bool] = False,
+               act_dtype: npt.DTypeLike = np.float16,
+               wt_dtype: npt.DTypeLike = np.float16,
+               n_splits: int = 1,
+               scale_factor: bool = True,
+               is_prefill: bool = False):
+        if n_splits == 1:
+            op = super().linear(input_node, output_channels,
+                                input_channels, bias, act_dtype,
+                                wt_dtype, scale_factor=scale_factor)
+        else:
+            op = super().dq_split_linear(input_node, n_splits,
+                                         output_channels, input_channels,
+                                         bias=bias, act_dtype=act_dtype,
+                                         wt_dtype=wt_dtype, scale_factor=scale_factor,
+                                         is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op
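The core of the refactor: every call site now goes through this single linear() helper, and n_splits decides whether a plain linear op or a dq_split_linear op is emitted, with scale_factor tied to whether channel-wise (group_size == 0) or groupwise scales are in play. A stand-alone toy of the dispatch pattern (the real methods build NPU graph ops, not tuples):

def linear_dispatch(node, out_ch, in_ch, *, n_splits=1, scale_factor=True, is_prefill=False):
    # Toy stand-in for LLMBaseNNFactory.linear: same decision, no graph building.
    if n_splits == 1:
        return ("linear", node, out_ch, in_ch, scale_factor)
    return ("dq_split_linear", node, n_splits, out_ch, in_ch, scale_factor, is_prefill)

# Channel-wise quantization: one matmul, scale folded as a single factor.
print(linear_dispatch("hidden", 4096, 4096, n_splits=1, scale_factor=True))
# Groupwise quantization: the same call site emits the split op instead.
print(linear_dispatch("hidden", 4096, 4096, n_splits=2, scale_factor=False, is_prefill=True))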


@@ -586,23 +586,10 @@ def run_decode(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -610,17 +597,6 @@ def run_decode(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
         layer_norm_0 = curr_layer.input_layernorm.weight.to(torch.float16)
@@ -839,23 +815,10 @@ def run_prefill(
         mlp_layer = curr_layer.mlp

         weights = []
-        if n_splits_linear == 1:
-            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                        attn_layer.k_proj_dq_list,
-                                        attn_layer.v_proj_dq_list,
-                                        attn_layer.o_proj_dq_list,
-                                        mlp_layer.gate_proj_dq_list,
-                                        mlp_layer.up_proj_dq_list):
-                weights.append((q.weight, q.scale))
-                weights.append((k.weight, k.scale))
-                weights.append((v.weight, v.scale))
-                weights.append((o.weight, o.scale))
-                weights.append((g.weight, g.scale))
-                weights.append((u.weight, u.scale))
-        else:
         for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                            attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                           mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                           mlp_layer.down_proj_dq_list]:
             l_weights = []
             scales = []
             for l in layer_list:
@@ -863,17 +826,6 @@ def run_prefill(
                 scales.append(l.scale)
             weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-        if n_splits_down_proj == 1:
-            for l in mlp_layer.down_proj_dq_list:
-                weights.append((l.weight, l.scale))
-        else:
-            l_weights = []
-            scales = []
-            for l in mlp_layer.down_proj_dq_list:
-                l_weights.append(l.weight)
-                scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)


@@ -28,7 +28,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
     vocab_size = model.config.vocab_size
     model_norm = model.model.norm
     lm_head = model.lm_head
+    if n_splits_linear == 1:
         weights = [(lm_head.weight, lm_head.scale)]
+    else:
+        lm_heads = lm_head.lm_heads
+        lm_head_weights = []
+        scales = []
+        for l in lm_heads:
+            lm_head_weights.append(l.weight)
+            scales.append(l.scale)
+        weights = [(torch.stack(lm_head_weights, axis=0),
+                    torch.stack(scales, axis=0))]
     if isinstance(weights[0], tuple):
         np_dtype = np.int8 if weights[0][0].dtype == torch.int8 else np.uint8
     else:  # FP16 Linear
@@ -44,13 +54,17 @@ def convert_lm_head_and_embedding(model, n_splits_linear, temp_dir, weight_dir):
         dtype=np_dtype,
         model_norm_weight=model_norm.weight.to(torch.float16),
         vocab_size=vocab_size,
+        n_splits=n_splits_linear
     )
     last_blob_path = update_names_of_IR_and_export_blob(new_lm_head, "lm_head", temp_dir)

     # save weights bins files
+    if n_splits_linear == 1:
         weight_numpy = [
             lm_head.weight.data.numpy(), lm_head.scale.data.numpy(),
         ]
+    else:
+        weight_numpy = [v.numpy() for v in weights[0]]

     for idx, weight in enumerate(weight_numpy):
         bin_file = os.path.join(weight_dir, f"model_lm_head_input_{1+idx}.bin")
@@ -83,17 +97,15 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
     mlp_layer = curr_layer.mlp

     weights = []
-    if n_splits_linear == 1:
-        weights = [
-            (attn_layer.W_pack.weight, attn_layer.W_pack.scale),
-            (attn_layer.o_proj.weight, attn_layer.o_proj.scale),
-            (mlp_layer.gate_proj.weight, mlp_layer.gate_proj.scale),
-            (mlp_layer.up_proj.weight, mlp_layer.up_proj.scale),
-            (mlp_layer.down_proj.weight, mlp_layer.down_proj.scale),
-        ]
-    else:
-        # TODO
-        pass
+    for layer_list in [attn_layer.W_pack_dq_list, attn_layer.o_proj_dq_list,
+                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                       mlp_layer.down_proj_dq_list]:
+        l_weights = []
+        scales = []
+        for l in layer_list:
+            l_weights.append(l.weight)
+            scales.append(l.scale)
+        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
     cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)
@@ -119,6 +131,9 @@ def convert_baichuan_layer(model, layer_idx, n_splits_linear, n_splits_down_proj
         mode="decode",
         transpose_value=transpose_value_cache,
         dtype=np_dtype,
+        n_splits_linear=n_splits_linear,
+        n_splits_down_proj=n_splits_down_proj,
+        group_size=group_size
     )
     rest_blob_path = update_names_of_IR_and_export_blob(single_decoder,
                                                         f"decoder_layer_{layer_idx}",


@@ -91,21 +91,21 @@ class LowBitLLMLMHead(LLMBaseNNFactory):
         self.head_dim = self.hidden_size // self.num_heads

         # define input, the order self.parameter matters
+        if n_splits == 1:
             input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
+        else:
+            input = self.create_input_op((1, self.batch_size, self.hidden_size))

         hidden_states = input

         # model norm and lm head
         model_norm_weight = self.constant(model_norm_weight)
         hidden_states = self.layer_norm(hidden_states, model_norm_weight)
-        if n_splits == 1:
         hidden_states = self.linear(
-            hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype
-        )
-        else:
-            hidden_states = self.dq_split_linear(
-                hidden_states, self.vocab_size, self.hidden_size, n_splits,
-                wt_dtype=dtype, scale_factor=False
-            )
+            hidden_states, self.vocab_size, self.hidden_size, bias=False, wt_dtype=self.dtype,
+            n_splits=n_splits,
+            scale_factor=(n_splits == 1),
         )

         # define outputs


@@ -174,40 +174,15 @@ def convert_llama_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     mlp_layer = curr_layer.mlp

     weights = []
-    if n_splits_linear == 1:
-        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                    attn_layer.k_proj_dq_list,
-                                    attn_layer.v_proj_dq_list,
-                                    attn_layer.o_proj_dq_list,
-                                    mlp_layer.gate_proj_dq_list,
-                                    mlp_layer.up_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-            weights.append((o.weight, o.scale))
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-    else:
     for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                        attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                       mlp_layer.down_proj_dq_list]:
         l_weights = []
         scales = []
         for l in layer_list:
             l_weights.append(l.weight)
             scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0),
-                            torch.stack(scales, axis=0)))
-
-    if n_splits_down_proj == 1:
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
-    else:
-        l_weights = []
-        scales = []
-        for l in mlp_layer.down_proj_dq_list:
-            l_weights.append(l.weight)
-            scales.append(l.scale)
         weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

     if hasattr(curr_layer.self_attn.rotary_emb, "cos_cached"):


@@ -109,36 +109,22 @@ class MiniCPMLMHead(LLMBaseNNFactory):
         hidden_states = self.layer_norm(hidden_states, model_norm_weight)
         if vocab_size == 122753:
             # for MiniCPM-2B-sft-bf16
-            if n_splits == 1:
             hidden_states_1 = self.linear(
-                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
+                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
+                n_splits=n_splits, scale_factor=(n_splits == 1)
             )
             hidden_states_2 = self.linear(
-                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype
-            )
-            else:
-                hidden_states_1 = self.dq_split_linear(
-                    hidden_states, 73440, self.hidden_size,
-                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
-                )
-                hidden_states_2 = self.dq_split_linear(
-                    hidden_states, 73440, self.hidden_size,
-                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                hidden_states, 73440, self.hidden_size, bias=False, wt_dtype=self.dtype,
+                n_splits=n_splits, scale_factor=(n_splits == 1)
             )
             hidden_states_2 = self.slice(hidden_states_2, begin=[0, 0, 0], end=[1, 1, 49313])
             hidden_states = self.concat(hidden_states_1, hidden_states_2, axis=2)
         else:
             # for MiniCPM-1B-sft-bf16
-            if n_splits == 1:
             hidden_states = self.linear(
                 hidden_states, self.vocab_size, self.hidden_size, bias=False,
-                wt_dtype=self.dtype
-            )
-            else:
-                hidden_states = self.dq_split_linear(
-                    hidden_states, self.vocab_size, self.hidden_size,
-                    n_splits=n_splits, wt_dtype=dtype, scale_factor=False
+                wt_dtype=self.dtype, n_splits=n_splits, scale_factor=(n_splits == 1)
             )

         # define outputs
@@ -245,40 +231,15 @@ def convert_minicpm_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     mlp_layer = curr_layer.mlp

     weights = []
-    if n_splits_linear == 1:
-        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                    attn_layer.k_proj_dq_list,
-                                    attn_layer.v_proj_dq_list,
-                                    attn_layer.o_proj_dq_list,
-                                    mlp_layer.gate_proj_dq_list,
-                                    mlp_layer.up_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-            weights.append((o.weight, o.scale))
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-    else:
     for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                        attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                       mlp_layer.down_proj_dq_list]:
         l_weights = []
         scales = []
         for l in layer_list:
            l_weights.append(l.weight)
            scales.append(l.scale)
-            weights.append((torch.stack(l_weights, axis=0),
-                            torch.stack(scales, axis=0)))
-
-    if n_splits_down_proj == 1:
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
-    else:
-        l_weights = []
-        scales = []
-        for l in mlp_layer.down_proj_dq_list:
-            l_weights.append(l.weight)
-            scales.append(l.scale)
         weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

     cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)


@@ -99,23 +99,10 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
     mlp_layer = curr_layer.mlp

     weights = []
-    if n_splits_linear == 1:
-        for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
-                                    attn_layer.k_proj_dq_list,
-                                    attn_layer.v_proj_dq_list,
-                                    attn_layer.o_proj_dq_list,
-                                    mlp_layer.gate_proj_dq_list,
-                                    mlp_layer.up_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-            weights.append((o.weight, o.scale))
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-    else:
     for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
                        attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
-                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                       mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list,
+                       mlp_layer.down_proj_dq_list]:
         l_weights = []
         scales = []
         for l in layer_list:
@@ -123,17 +110,6 @@ def convert_qwen_layer(model, layer_idx, n_splits_linear, n_splits_down_proj,
             scales.append(l.scale)
         weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

-    if n_splits_down_proj == 1:
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
-    else:
-        l_weights = []
-        scales = []
-        for l in mlp_layer.down_proj_dq_list:
-            l_weights.append(l.weight)
-            scales.append(l.scale)
-        weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
-
     q_bias = attn_layer.q_proj_dq_list.q_proj_dq_0.bias.to(torch.float16)
     k_bias = attn_layer.k_proj_dq_list.k_proj_dq_0.bias.to(torch.float16)
     v_bias = attn_layer.v_proj_dq_list.v_proj_dq_0.bias.to(torch.float16)