Groupwise prefill optimization (#12291)

* except lm_head

* remove

* support gw lm_head

* update

* fix

* remove run.bat

* fix style

* support llama3

* slice -> split

* remove debug

* fix style

* add dpu
Author: Yina Chen, 2024-10-30 08:59:45 +02:00 (committed by GitHub)
parent 540eaeb12c
commit 70037ad55f
3 changed files with 105 additions and 147 deletions
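
The change that both run_prefill diffs below share is how per-split quantized weights are packed for the prefill graph: when a projection is split into several groups, its per-split weights and scales are stacked along a new leading axis instead of being appended one (weight, scale) pair at a time. A minimal sketch of that packing pattern follows; DQLinearSplit and the shapes are illustrative stand-ins for the *_dq_list entries in the diff, not actual ipex-llm classes.

import torch

# Hypothetical stand-in for one quantized linear split (weight + scale),
# mirroring the q/k/v/o/gate/up "*_dq_list" entries touched in this commit.
class DQLinearSplit:
    def __init__(self, out_features, in_features):
        self.weight = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
        self.scale = torch.rand(out_features, dtype=torch.float16)

n_splits_linear = 4
hidden_size = 4096
q_proj_dq_list = [DQLinearSplit(4096, hidden_size // n_splits_linear)
                  for _ in range(n_splits_linear)]

weights = []
if n_splits_linear == 1:
    # Previous packing: one (weight, scale) tuple per split.
    weights.append((q_proj_dq_list[0].weight, q_proj_dq_list[0].scale))
else:
    # Groupwise packing: stack every split of a projection along a new axis 0 so the
    # prefill graph receives one grouped weight tensor and one grouped scale tensor.
    l_weights = [split.weight for split in q_proj_dq_list]
    scales = [split.scale for split in q_proj_dq_list]
    weights.append((torch.stack(l_weights, dim=0), torch.stack(scales, dim=0)))

print(weights[0][0].shape)  # torch.Size([4, 4096, 1024]) for the grouped case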

Changed file 1 of 3:

@@ -188,7 +188,10 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             new_value_states = self.convert_to_fp16(curr_key_values[i][1])
 
         print("start compiling")
-        self.compile()
+        if mode == "prefill":
+            self.compile(npu_dpu_groups=6)
+        else:
+            self.compile()
 
     def build_decoder(
         self,
@@ -753,19 +756,40 @@ def run_prefill(
 
         weights = []
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)

Changed file 2 of 3:

@@ -165,60 +165,21 @@ class LLMBaseNNFactory(NNFactory):
             )
         else:
             hidden_states = self.unsqueeze(hidden_states, axis=0)
-            if mode == "prefill":
-                query_states_to_concat = []
-                key_states_to_concat = []
-                value_states_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * groupsize],
-                                                   end=[1, seq_len, (i + 1) * groupsize])
-                    query_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    key_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    value_states_to_concat.append(
-                        self.linear(
-                            sub_hidden_states,
-                            num_key_value_heads * head_dim,
-                            groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype,
-                            scale_factor=(self.group_size == 0)
-                        )
-                    )
-                query_states = sum(query_states_to_concat)
-                key_states = sum(key_states_to_concat)
-                value_states = sum(value_states_to_concat)
-            else:
-                query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
-                key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                  hidden_size, self.n_splits_linear,
-                                                  wt_dtype=self.dtype,
-                                                  scale_factor=(self.group_size == 0))
-                value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
-                                                    hidden_size, self.n_splits_linear,
-                                                    wt_dtype=self.dtype,
-                                                    scale_factor=(self.group_size == 0))
+            query_states = self.dq_split_linear(hidden_states, num_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
+            key_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                              hidden_size, self.n_splits_linear,
+                                              wt_dtype=self.dtype,
+                                              scale_factor=(self.group_size == 0),
+                                              is_prefill=(mode == "prefill"))
+            value_states = self.dq_split_linear(hidden_states, num_key_value_heads * head_dim,
+                                                hidden_size, self.n_splits_linear,
+                                                wt_dtype=self.dtype,
+                                                scale_factor=(self.group_size == 0),
+                                                is_prefill=(mode == "prefill"))
 
         if q_bias is not None:
             query_states = query_states + q_bias
@@ -296,23 +257,10 @@ class LLMBaseNNFactory(NNFactory):
                 attn_output, hidden_size, hidden_size, bias=False, wt_dtype=self.dtype
             )
         else:
-            if mode == "prefill":
-                attn_output_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_attn_output = self.slice(attn_output,
-                                                 begin=[0, 0, i * groupsize],
-                                                 end=[1, seq_len, (i + 1) * groupsize])
-                    attn_output_to_concat.append(
-                        self.linear(
-                            sub_attn_output, hidden_size, groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                attn_output = sum(attn_output_to_concat)
-            else:
-                attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
-                                                   self.n_splits_linear, wt_dtype=self.dtype,
-                                                   scale_factor=(self.group_size == 0))
+            attn_output = self.dq_split_linear(attn_output, hidden_size, hidden_size,
+                                               self.n_splits_linear, wt_dtype=self.dtype,
+                                               scale_factor=(self.group_size == 0),
+                                               is_prefill=(mode == "prefill"))
 
         return attn_output, new_key_states, new_value_states
@@ -488,37 +436,14 @@ class LLMBaseNNFactory(NNFactory):
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                gate_up_groupsize = self.hidden_size // self.n_splits_linear
-                mm1_to_concat = []
-                mm2_to_concat = []
-                for i in range(self.n_splits_linear):
-                    sub_hidden_states = self.slice(hidden_states,
-                                                   begin=[0, 0, i * gate_up_groupsize],
-                                                   end=[1, seq_len, (i + 1) * gate_up_groupsize])
-                    mm1_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                    mm2_to_concat.append(
-                        self.linear(
-                            sub_hidden_states, self.intermediate_size, gate_up_groupsize,
-                            bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                mm1 = sum(mm1_to_concat)
-                mm2 = sum(mm2_to_concat)
-            else:
-                mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
-                mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
-                                           self.n_splits_linear, wt_dtype=self.dtype,
-                                           scale_factor=(self.group_size == 0))
+            mm1 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
+            mm2 = self.dq_split_linear(hidden_states, self.intermediate_size, self.hidden_size,
+                                       self.n_splits_linear, wt_dtype=self.dtype,
+                                       scale_factor=(self.group_size == 0),
+                                       is_prefill=(mode == "prefill"))
             mm1 = self.eltwise_mul(self.swish(mm1), mm2)  # type: ignore[attr-defined]
 
         if self.n_splits_down_proj == 1:
@@ -527,23 +452,10 @@ class LLMBaseNNFactory(NNFactory):
             )
         else:
             invalidInputError(seq_len > 0, "seq_len should be provided if use split linear")
-            if mode == "prefill":
-                down_groupsize = self.intermediate_size // self.n_splits_down_proj
-                hidden_states_to_concat = []
-                for i in range(self.n_splits_down_proj):
-                    sub_mm1 = self.slice(mm1, begin=[0, 0, i * down_groupsize],
-                                         end=[1, seq_len, (i + 1) * down_groupsize])
-                    hidden_states_to_concat.append(
-                        self.linear(
-                            sub_mm1, self.hidden_size, down_groupsize, bias=False,
-                            wt_dtype=self.dtype, scale_factor=(self.group_size == 0)
-                        )
-                    )
-                hidden_states = sum(hidden_states_to_concat)
-            else:
-                hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
-                                                     self.n_splits_down_proj, wt_dtype=self.dtype,
-                                                     scale_factor=(self.group_size == 0))
+            hidden_states = self.dq_split_linear(mm1, self.hidden_size, self.intermediate_size,
+                                                 self.n_splits_down_proj, wt_dtype=self.dtype,
+                                                 scale_factor=(self.group_size == 0),
+                                                 is_prefill=(mode == "prefill"))
         return hidden_states
 
     def layer_norm(self, hidden_states, layernorm_weight):
@@ -660,9 +572,11 @@ class LLMBaseNNFactory(NNFactory):
                         n_splits: int,
                         act_dtype: npt.DTypeLike = np.float16,
                         wt_dtype: npt.DTypeLike = np.float16,
-                        scale_factor: bool = False):
+                        scale_factor: bool = False,
+                        is_prefill: bool = False):
         op = super().dq_split_linear(input_node, n_splits, output_channels, input_channels,
-                                     False, act_dtype, wt_dtype, scale_factor)
+                                     False, act_dtype, wt_dtype, scale_factor,
+                                     is_prefill=is_prefill)
         self.linear_ops.append(op)
         return op
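
All of the branches removed above in LLMBaseNNFactory followed the same prefill-only pattern: slice the activations into n_splits groups along the feature dimension, run a per-group linear against the matching weight slice, and sum the partial outputs. Folding that path into dq_split_linear via the new is_prefill flag is safe because slice-and-sum computes exactly the full linear. Below is a plain-PyTorch illustration of that identity; it uses ordinary dense tensors and made-up shapes, not the NNFactory ops from this file.

import torch

torch.manual_seed(0)

# Shapes are arbitrary; groupsize mirrors the removed code's hidden_size // n_splits.
seq_len, hidden_size, out_size, n_splits = 8, 64, 32, 4
groupsize = hidden_size // n_splits

x = torch.randn(1, seq_len, hidden_size)   # activations
w = torch.randn(out_size, hidden_size)     # full projection weight

full = x @ w.T                              # one linear over the whole feature dim

partials = []
for i in range(n_splits):
    x_i = x[..., i * groupsize:(i + 1) * groupsize]   # slice of the activations
    w_i = w[:, i * groupsize:(i + 1) * groupsize]     # matching slice of the weight
    partials.append(x_i @ w_i.T)                      # per-group linear
split_sum = sum(partials)                             # what the removed prefill path summed

print(torch.allclose(full, split_sum, atol=1e-5))     # True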

Changed file 3 of 3:

@@ -827,20 +827,40 @@ def run_prefill(
         mlp_layer = curr_layer.mlp
 
         weights = []
-        for q, k, v in zip(attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
-                           attn_layer.v_proj_dq_list):
-            weights.append((q.weight, q.scale))
-            weights.append((k.weight, k.scale))
-            weights.append((v.weight, v.scale))
-        for l in attn_layer.o_proj_dq_list:
-            weights.append((l.weight, l.scale))
-        for g, u in zip(mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list):
-            weights.append((g.weight, g.scale))
-            weights.append((u.weight, u.scale))
-        for l in mlp_layer.down_proj_dq_list:
-            weights.append((l.weight, l.scale))
+        if n_splits_linear == 1:
+            for q, k, v, o, g, u in zip(attn_layer.q_proj_dq_list,
+                                        attn_layer.k_proj_dq_list,
+                                        attn_layer.v_proj_dq_list,
+                                        attn_layer.o_proj_dq_list,
+                                        mlp_layer.gate_proj_dq_list,
+                                        mlp_layer.up_proj_dq_list):
+                weights.append((q.weight, q.scale))
+                weights.append((k.weight, k.scale))
+                weights.append((v.weight, v.scale))
+                weights.append((o.weight, o.scale))
+                weights.append((g.weight, g.scale))
+                weights.append((u.weight, u.scale))
+        else:
+            for layer_list in [attn_layer.q_proj_dq_list, attn_layer.k_proj_dq_list,
+                               attn_layer.v_proj_dq_list, attn_layer.o_proj_dq_list,
+                               mlp_layer.gate_proj_dq_list, mlp_layer.up_proj_dq_list]:
+                l_weights = []
+                scales = []
+                for l in layer_list:
+                    l_weights.append(l.weight)
+                    scales.append(l.scale)
+                weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
+
+        if n_splits_down_proj == 1:
+            for l in mlp_layer.down_proj_dq_list:
+                weights.append((l.weight, l.scale))
+        else:
+            l_weights = []
+            scales = []
+            for l in mlp_layer.down_proj_dq_list:
+                l_weights.append(l.weight)
+                scales.append(l.scale)
+            weights.append((torch.stack(l_weights, axis=0), torch.stack(scales, axis=0)))
 
         cached_cos = curr_layer.self_attn.rotary_emb.cos_cached.to(torch.float16)
         cached_sin = curr_layer.self_attn.rotary_emb.sin_cached.to(torch.float16)