Groupwise prefill optimization (#12291)
* except lm_head
* remove
* support gw lm_head
* update
* fix
* remove run.bat
* fix style
* support llama3
* slice -> split
* remove debug
* fix style
* add dpu
parent 540eaeb12c
commit 70037ad55f
3 changed files with 105 additions and 147 deletions
@@ -188,7 +188,10 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])

-        print("start compiling")
-        self.compile()
+        if mode == "prefill":
+            self.compile(npu_dpu_groups=6)
+        else:
+            self.compile()

     def build_decoder(
         self,

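Note: the hunk above removes the debug print ("remove debug" in the commit message) and compiles the prefill graph with npu_dpu_groups=6 ("add dpu"), while decode keeps the default compile() call. A minimal sketch of the same mode switch, assuming a factory-like object whose compile() accepts an npu_dpu_groups hint as in this patch; the helper name and its default value are illustrative only:

    # Sketch only: `factory` stands in for a LowBitLlamaMultiDecoderlayer-like
    # object; compile(npu_dpu_groups=...) is assumed from this diff.
    def compile_for_mode(factory, mode: str, dpu_groups: int = 6):
        if mode == "prefill":
            # Prefill graphs get an explicit DPU group count.
            factory.compile(npu_dpu_groups=dpu_groups)
        else:
            # Decode keeps the default compilation path.
            factory.compile()
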
@@ -753,19 +756,40 @@ def run_prefill(

         weights = []
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

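Note: with n_splits_linear > 1, each projection now contributes a single stacked (weights, scales) pair instead of one pair per split, which is what lets the prefill graph consume one grouped tensor per projection. A minimal, self-contained sketch of that stacking; FakeDQLinear, the shapes, and the dtypes are made-up stand-ins for the real dequantized-linear splits:

    import torch

    class FakeDQLinear:
        """Stand-in for one split of a low-bit (dequantizable) linear layer."""
        def __init__(self, out_features: int, in_features: int):
            self.weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
            self.scale = torch.rand(out_features, dtype=torch.float16)

    n_splits = 2
    q_proj_dq_list = [FakeDQLinear(16, 32) for _ in range(n_splits)]

    if n_splits == 1:
        # One (weight, scale) entry per sub-linear, as before this patch.
        weights = [(l.weight, l.scale) for l in q_proj_dq_list]
    else:
        # One stacked entry per projection: weights of shape
        # (n_splits, out_features, in_features), scales of shape (n_splits, out_features).
        weights = [(torch.stack([l.weight for l in q_proj_dq_list]),
                    torch.stack([l.scale for l in q_proj_dq_list]))]

    print(weights[0][0].shape, weights[0][1].shape)
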
@@ -165,60 +165,21 @@ class LLMBaseNNFactory(NNFactory):
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))

         if q_bias is not None:
             query_states = query_states + q_bias

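Note: this is the "slice -> split" change from the commit message. The prefill-only path that sliced hidden_states into n_splits_linear groups, ran one small linear per group, and summed the partial outputs is removed; prefill and decode both call dq_split_linear, with the mode passed as is_prefill. The arithmetic behind that equivalence, as a plain PyTorch sketch (shapes and names are illustrative, not the NPU graph API): slicing the input along the hidden dimension and summing per-group matmuls equals one grouped contraction over per-group weights.

    import torch

    # Illustrative shapes only.
    hidden_size, out_features, seq_len, n_splits = 64, 48, 8, 4
    group = hidden_size // n_splits

    x = torch.randn(1, seq_len, hidden_size, dtype=torch.float64)
    w = torch.randn(out_features, hidden_size, dtype=torch.float64)  # full weight, [out, in]

    # Old prefill path: slice the activation per group, run one small linear
    # per split, and sum the partial outputs.
    y_slice_sum = sum(
        x[..., i * group:(i + 1) * group] @ w[:, i * group:(i + 1) * group].T
        for i in range(n_splits)
    )

    # Split-linear view: the same result as one grouped contraction over
    # weights stacked per split, shape (n_splits, out_features, group).
    w_stacked = torch.stack([w[:, i * group:(i + 1) * group] for i in range(n_splits)])
    y_split = torch.einsum('bsng,nog->bso', x.view(1, seq_len, n_splits, group), w_stacked)

    torch.testing.assert_close(y_slice_sum, y_split)
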
@@ -296,23 +257,10 @@ class LLMBaseNNFactory(NNFactory):
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))

         return attn_output, new_key_states, new_value_states

@@ -488,37 +436,14 @@ class LLMBaseNNFactory(NNFactory):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]

         if self.n_splits_down_proj == 1:

@@ -527,23 +452,10 @@ class LLMBaseNNFactory(NNFactory):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states

     def layer_norm(self, hidden_states, layernorm_weight):

@@ -660,9 +572,11 @@ class LLMBaseNNFactory(NNFactory):
                         n_splits: int,
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     False, act_dtype, wt_dtype, scale_factor,
+                                     is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op

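Note: the wrapper above only gains a keyword with a False default and forwards it to the parent op, so existing decode call sites are untouched while the attention/MLP builders can pass is_prefill=(mode == "prefill"). A generic sketch of that pattern; Base and Factory are made-up stand-ins for the real NNFactory hierarchy, not its API:

    class Base:
        def dq_split_linear(self, input_node, n_splits, out_ch, in_ch,
                            bias=False, is_prefill=False):
            # The real implementation would emit an NPU op; here we only
            # record whether the prefill variant was requested.
            return {"node": input_node, "n_splits": n_splits, "prefill": is_prefill}

    class Factory(Base):
        def dq_split_linear(self, input_node, out_ch, in_ch, n_splits,
                            is_prefill=False):
            # The new keyword defaults to False, so old callers keep their
            # behaviour; prefill callers opt in explicitly.
            return super().dq_split_linear(input_node, n_splits, out_ch, in_ch,
                                           False, is_prefill=is_prefill)

    print(Factory().dq_split_linear("hidden", 128, 128, 2, is_prefill=True))
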
@@ -827,20 +827,40 @@ def run_prefill(
         mlp_layer = curr_layer.mlp

         weights = []
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))

         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)