optimize npu llama2 first token performance (#11451)

Author: Yishuo Wang (committed via GitHub)
Date: 2024-06-27 17:37:33 +08:00
Parent: 4e4ecd5095
Commit: 029ff15d28

@@ -219,12 +219,24 @@ def llama_attention_forward(
     else:
         causal_mask = None
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states,
-        key_states,
-        value_states,
-        attn_mask=causal_mask,
-        is_causal=self.is_causal and attention_mask is None and q_len > 1,
-    )
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+    else:
+        # second+ token
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
 
     attn_output = attn_output.transpose(1, 2).contiguous()
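
The change routes prefill (first-token) attention to the NPU acceleration library's scaled_dot_product_attention and keeps PyTorch's implementation for the decode path, where the query carries a single new token against a longer KV cache, so query_states.size(2) != key_states.size(2). Below is a minimal standalone sketch of that dispatch pattern; the helper name dispatch_sdpa, the try/except fallback, and the toy tensor shapes are illustrative and not part of the commit, and self.is_causal is omitted for brevity.

# Sketch of the prefill-vs-decode SDPA dispatch, assuming tensors shaped
# (batch, num_heads, seq_len, head_dim). During prefill the query and key
# sequence lengths match; during decode the query is a single token.
import torch

try:
    # Present only when intel_npu_acceleration_library is installed; fall
    # back to torch's SDPA otherwise (assumption for a runnable sketch).
    from intel_npu_acceleration_library.functional import (
        scaled_dot_product_attention as npu_sdpa,
    )
except ImportError:
    npu_sdpa = None


def dispatch_sdpa(query_states, key_states, value_states, causal_mask=None):
    q_len = query_states.size(2)
    # Prefill (first token): q_len == kv_len, use the NPU kernel if available.
    if npu_sdpa is not None and q_len == key_states.size(2):
        return npu_sdpa(
            query_states, key_states, value_states,
            attn_mask=causal_mask,
            is_causal=causal_mask is None and q_len > 1,
        )
    # Decode (second+ token), or no NPU library: use stock PyTorch SDPA.
    return torch.nn.functional.scaled_dot_product_attention(
        query_states, key_states, value_states,
        attn_mask=causal_mask,
        is_causal=causal_mask is None and q_len > 1,
    )


# Toy usage: prefill with 8 query tokens attending over 8 keys.
q = torch.randn(1, 32, 8, 128)
k = torch.randn(1, 32, 8, 128)
v = torch.randn(1, 32, 8, 128)
out = dispatch_sdpa(q, k, v)
print(out.shape)  # torch.Size([1, 32, 8, 128])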