optimize npu llama2 first token performance (#11451)
parent 4e4ecd5095
commit 029ff15d28
1 changed file with 19 additions and 7 deletions
@@ -219,13 +219,25 @@ def llama_attention_forward(
     else:
         causal_mask = None
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states,
-        key_states,
-        value_states,
-        attn_mask=causal_mask,
-        is_causal=self.is_causal and attention_mask is None and q_len > 1,
-    )
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+    else:
+        # second+ token
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
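The change splits attention handling between the prefill and decode phases: when the query and key sequence lengths match (the first token, i.e. the full prompt), attention is routed to scaled_dot_product_attention from intel_npu_acceleration_library; for later tokens (a single-token query against the cached keys) the original torch.nn.functional path is kept. During prefill the query covers the whole prompt, so its length equals the key length; during decode the KV cache makes key_states longer than the one-token query, which is why the size comparison distinguishes the two phases. The sketch below is a minimal, self-contained illustration of that dispatch, not the repository's code: the sdpa_dispatch wrapper, the (batch, num_heads, seq_len, head_dim) tensor layout, and the ImportError fallback are assumptions added so the example runs even without intel_npu_acceleration_library installed.

    # Minimal sketch of the prefill/decode dispatch introduced by this commit.
    # Assumes query/key/value are laid out as (batch, num_heads, seq_len, head_dim).
    import torch
    import torch.nn.functional as F

    def sdpa_dispatch(query_states, key_states, value_states,
                      causal_mask=None, is_causal_flag=True):
        q_len = query_states.size(2)
        if query_states.size(2) == key_states.size(2):
            # first token (prefill): the whole prompt attends to itself,
            # so the large matmul is handed to the NPU library's kernel
            try:
                from intel_npu_acceleration_library.functional import scaled_dot_product_attention
            except ImportError:
                # fallback so this sketch also runs without the NPU library
                scaled_dot_product_attention = F.scaled_dot_product_attention
            return scaled_dot_product_attention(
                query_states, key_states, value_states,
                attn_mask=causal_mask,
                is_causal=is_causal_flag and causal_mask is None and q_len > 1,
            )
        # second+ token (decode): q_len == 1, keep torch's SDPA
        return F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=causal_mask,
            is_causal=is_causal_flag and causal_mask is None and q_len > 1,
        )

    # usage on dummy tensors: prefill (q_len == kv_len) vs. decode (q_len == 1)
    q = torch.randn(1, 32, 16, 64)
    k = torch.randn(1, 32, 16, 64)
    v = torch.randn(1, 32, 16, 64)
    prefill_out = sdpa_dispatch(q, k, v)           # first-token path
    decode_out = sdpa_dispatch(q[:, :, :1], k, v)  # second+ token path

Offloading only the prefill step follows from the commit title: the first token dominates time-to-first-token because it processes the whole prompt at once, whereas per-token decode work is comparatively small and is left on the existing torch path.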