[NPU] Llama2 prefill use ov sdp (#12310)
* prefill use sdp
* add param
* update
* fix style
* fix style
* meet comments
parent eda764909c
commit 05c5d0267a
2 changed files with 46 additions and 19 deletions
@@ -110,12 +110,19 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         # define input, the order self.parameter matters
         input = self.create_input_op((self.batch_size, self.seq_len, self.hidden_size))
 
+        # llama2 use ov sdp, other models need to test
+        use_prefill_sdp = self.intermediate_size == 11008
+
         # Self Attention
         if mode == "decode":
             attention_mask = self.create_input_op((self.batch_size, 1, 1, self.max_seq_len + 1),
                                                   dtype=np.int64)
         else:
-            attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len, self.seq_len),
-                                                  dtype=np.int64)
+            if use_prefill_sdp:
+                attention_mask = None
+            else:
+                attention_mask = self.create_input_op((self.batch_size, 1, self.seq_len,
+                                                       self.seq_len),
+                                                      dtype=np.int64)
 
         position_ids = self.create_input_op((self.batch_size, self.seq_len), dtype=np.int64)
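Note on the hunk above: 11008 is the MLP intermediate size of Llama-2-7B, so the intermediate_size check is effectively a "this model is Llama2" heuristic, and when it fires the prefill graph simply does not declare an attention-mask input (the fused SDP op applies the causal mask itself, see the last hunk below). A standalone sketch of that gating, with illustrative names that are not taken from the patch:

# Illustrative sketch, not part of the patch: the same gating idea in plain Python.
# 11008 is assumed here to identify Llama-2-7B by its MLP width; other widths keep
# the old masked prefill path.
LLAMA2_7B_INTERMEDIATE_SIZE = 11008  # hypothetical constant for illustration

def prefill_graph_inputs(batch_size, seq_len, hidden_size, intermediate_size):
    """Return the (name, shape) pairs a prefill graph would declare."""
    use_prefill_sdp = intermediate_size == LLAMA2_7B_INTERMEDIATE_SIZE
    inputs = [("hidden_states", (batch_size, seq_len, hidden_size)),
              ("position_ids", (batch_size, seq_len))]
    if not use_prefill_sdp:
        # only the non-SDP path needs an explicit int64 causal-mask input
        inputs.insert(1, ("attention_mask", (batch_size, 1, seq_len, seq_len)))
    return inputs

print(prefill_graph_inputs(1, 1024, 4096, 11008))   # Llama2-7B width: no mask input
print(prefill_graph_inputs(1, 1024, 4096, 14336))   # other width: mask input kept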
@@ -177,6 +184,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
                 post_attention_layernorm_weight=post_attn_layernorm_weights[i],
                 past_key=past_keys[i],
                 past_value=past_values[i],
+                use_prefill_sdp=use_prefill_sdp,
             )
             curr_key_values.append((new_key_states, new_value_states))
 
@@ -202,6 +210,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
         post_attention_layernorm_weight,
         past_key=None,
         past_value=None,
+        use_prefill_sdp=False,
     ):
 
         residual = hidden_states
@@ -220,6 +229,7 @@ class LowBitLlamaMultiDecoderlayer(LLMBaseNNFactory):
             num_key_value_heads=self.num_key_value_heads,
             head_dim=self.head_dim,
             seq_len=self.seq_len,
+            use_prefill_sdp=use_prefill_sdp,
         )
         hidden_states = self.eltwise_add(residual, attn_output)
         residual = hidden_states
@@ -427,6 +437,7 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         )
         self.layer_norm_0 = layer_norm_0
         self.layer_norm_1 = layer_norm_1
+        self.use_prefill_sdp = intermediate_size == 11008
 
     def forward(
         self,
@@ -451,6 +462,10 @@ class FusedLlamaLowBitDecoderlayer(torch.nn.Module):
         seq_len = hidden_states.shape[1]
 
         backend_cls = self.backend_cls_prefill
-        inputs = (hidden_states.to(torch.float16),
-                  attention_mask.to(torch.int64),
-                  position_ids.to(torch.int64))
+        if self.use_prefill_sdp:
+            inputs = (hidden_states.to(torch.float16),
+                      position_ids.to(torch.int64))
+        else:
+            inputs = (hidden_states.to(torch.float16),
+                      attention_mask.to(torch.int64),
+                      position_ids.to(torch.int64))
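The host-side forward in the hunk above has to keep its input tuple in sync with the graph's declared parameters: when the graph was built without a mask input, passing a mask would shift the parameter order. A hedged sketch of that selection (the helper name is illustrative, not from the patch):

# Illustrative helper, not from the patch: mirror the graph-side decision so the
# runtime input tuple matches the graph's declared inputs exactly.
import torch

def build_prefill_inputs(hidden_states, attention_mask, position_ids, use_prefill_sdp):
    if use_prefill_sdp:
        # graph has no mask parameter, so the tuple must omit it
        return (hidden_states.to(torch.float16),
                position_ids.to(torch.int64))
    return (hidden_states.to(torch.float16),
            attention_mask.to(torch.int64),
            position_ids.to(torch.int64))

hs = torch.randn(1, 16, 4096)
mask = torch.zeros(1, 1, 16, 16, dtype=torch.int64)
pos = torch.arange(16).unsqueeze(0)
print(len(build_prefill_inputs(hs, mask, pos, True)),    # 2 inputs
      len(build_prefill_inputs(hs, mask, pos, False)))   # 3 inputs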
@@ -135,10 +135,10 @@ class LLMBaseNNFactory(NNFactory):
                   seq_len,
                   q_bias=None,
                   k_bias=None,
-                  v_bias=None):
+                  v_bias=None,
+                  use_prefill_sdp=False):
         hidden_size = num_heads * head_dim
         num_key_value_groups = num_heads // num_key_value_heads
-        groupsize = hidden_size // self.n_splits_linear
         if self.n_splits_linear == 1:
             query_states = self.linear(
                 hidden_states,
@@ -200,8 +200,13 @@ class LLMBaseNNFactory(NNFactory):
 
         query_states = self.transpose(query_states, [0, 2, 1, 3])
         key_states = self.transpose(key_states, [0, 2, 1, 3])
+        use_ov_sdp = (mode == "prefill") and use_prefill_sdp
         if self.transpose_value:
-            value_states = self.transpose(value_states, [0, 2, 3, 1])
+            new_value_states = self.transpose(value_states, [0, 2, 3, 1])
+            if use_ov_sdp:
+                value_states = self.transpose(value_states, [0, 2, 1, 3])
+            else:
+                value_states = new_value_states
         else:
             value_states = self.transpose(value_states, [0, 2, 1, 3])
 
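The hunk above ends up with two value layouts: the [0, 2, 3, 1] copy, i.e. (batch, heads, head_dim, seq), that the existing pre-transposed matmul path and the returned KV cache expect, and a [0, 2, 1, 3] copy, i.e. (batch, heads, seq, head_dim), which is the conventional layout the SDP op consumes. A hedged, self-contained check of that layout algebra (not code from the patch, torch used as a stand-in):

# Illustrative check, not from the patch: the two value layouts give the same
# attention output, which is why the SDP path re-transposes values to (B, H, S, D)
# while the matmul path keeps the pre-transposed (B, H, D, S) copy.
import torch

B, H, S, D = 1, 2, 5, 4
value_bshd = torch.randn(B, S, H, D)           # (batch, seq, heads, head_dim)
value_sdp = value_bshd.permute(0, 2, 1, 3)     # [0, 2, 1, 3] -> (B, H, S, D), SDP layout
value_pre = value_bshd.permute(0, 2, 3, 1)     # [0, 2, 3, 1] -> (B, H, D, S), matmul layout

attn_weight = torch.softmax(torch.randn(B, H, S, S), dim=-1)
out_sdp = attn_weight @ value_sdp                                       # (B, H, S, D)
out_pre = (value_pre @ attn_weight.transpose(-1, -2)).transpose(-1, -2)
print(torch.allclose(out_sdp, out_pre, atol=1e-6))                      # True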
@@ -216,7 +221,6 @@ class LLMBaseNNFactory(NNFactory):
                 head_dim=head_dim,
             )
         new_key_states = key_states
-        new_value_states = value_states
 
         if mode == "decode":
             key_states = self.concat(past_key, key_states, axis=-2)
@@ -238,7 +242,15 @@ class LLMBaseNNFactory(NNFactory):
                 num_key_value_heads=num_key_value_heads,
                 kv_seq_len=kv_seq_len,
                 head_dim=head_dim,
-                transpose=self.transpose_value)
+                transpose=(self.transpose_value and (not use_ov_sdp)))
+        if use_ov_sdp:
+            value_states = self.convert_to_fp32(value_states)
+            key_states = self.convert_to_fp32(key_states)
+            query_states = self.convert_to_fp32(query_states)
+            attn_output = self.scaled_dot_product_attention(
+                query_states, key_states, value_states, None, True)
+            attn_output = self.convert_to_fp16(attn_output)
+        else:
             attn_weight = self.matmul(query_states, key_states, False, True) / (
                 math.sqrt(head_dim)
             )
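In the SDP branch above, query/key/value are up-cast to fp32, the fused scaled_dot_product_attention node is called with no explicit mask and a trailing True (presumably the causal flag), and the result is cast back to fp16; the else branch keeps the original matmul / sqrt(head_dim) + softmax path. A hedged PyTorch stand-in for the two branches, with torch.nn.functional.scaled_dot_product_attention substituting for the OpenVINO node:

# Hedged stand-in, not the patch's code: torch SDPA replaces the OpenVINO
# scaled_dot_product_attention node; shapes are (batch, heads, seq, head_dim).
import math
import torch
import torch.nn.functional as F

def attention(query, key, value, use_ov_sdp):
    if use_ov_sdp:
        # fp32 compute, causal masking handled inside the fused op, back to fp16
        out = F.scaled_dot_product_attention(query.float(), key.float(), value.float(),
                                             attn_mask=None, is_causal=True)
        return out.to(torch.float16)
    # manual path: explicit scores, causal mask, softmax, weighted sum of values
    scores = query.float() @ key.float().transpose(-1, -2) / math.sqrt(query.shape[-1])
    causal = torch.triu(torch.full(scores.shape[-2:], float("-inf")), diagonal=1)
    weights = torch.softmax(scores + causal, dim=-1)
    return (weights @ value.float()).to(torch.float16)

q, k, v = (torch.randn(1, 32, 16, 128, dtype=torch.float16) for _ in range(3))
print(torch.allclose(attention(q, k, v, True), attention(q, k, v, False), atol=1e-2))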