optimize npu llama2 first token performance (#11451)
This commit is contained in:
parent 4e4ecd5095
commit 029ff15d28
1 changed file with 19 additions and 7 deletions
@@ -219,13 +219,25 @@ def llama_attention_forward(
     else:
         causal_mask = None

-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query_states,
-        key_states,
-        value_states,
-        attn_mask=causal_mask,
-        is_causal=self.is_causal and attention_mask is None and q_len > 1,
-    )
+    if query_states.size(2) == key_states.size(2):
+        # first token
+        from intel_npu_acceleration_library.functional import scaled_dot_product_attention
+        attn_output = scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )
+    else:
+        # second+ token
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=causal_mask,
+            is_causal=self.is_causal and causal_mask is None and q_len > 1,
+        )

     attn_output = attn_output.transpose(1, 2).contiguous()
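For context, the new branch dispatches on whether the query length equals the key/value length: on the first (prefill) token pass they match, so the NPU-accelerated scaled_dot_product_attention from intel_npu_acceleration_library.functional is used, while on second-and-later decode steps (query length 1 against a longer KV cache) the stock torch implementation is kept. Below is a minimal, hedged sketch of that dispatch only; the helper name sdpa_dispatch is hypothetical, and torch's SDPA stands in for the NPU kernel so the sketch runs without the NPU library installed.

import torch
import torch.nn.functional as F


def sdpa_dispatch(query_states, key_states, value_states, causal_mask=None, is_causal=True):
    # q_len == kv_len only on the first (prefill) pass, when the whole prompt
    # is scored at once; later decode steps have q_len == 1 with a longer KV cache.
    q_len = query_states.size(2)
    if q_len == key_states.size(2):
        # First token: the patch routes this case to the NPU kernel
        # (intel_npu_acceleration_library.functional.scaled_dot_product_attention).
        # torch's SDPA is used here only so this sketch runs anywhere.
        sdpa = F.scaled_dot_product_attention
    else:
        # Second+ token: keep the stock torch implementation, as in the patch.
        sdpa = F.scaled_dot_product_attention
    return sdpa(
        query_states,
        key_states,
        value_states,
        attn_mask=causal_mask,
        is_causal=is_causal and causal_mask is None and q_len > 1,
    )


# Usage: batch=1, heads=2, seq_len=8, head_dim=4 -> takes the prefill path.
q = torch.randn(1, 2, 8, 4)
k = torch.randn(1, 2, 8, 4)
v = torch.randn(1, 2, 8, 4)
print(sdpa_dispatch(q, k, v).shape)  # torch.Size([1, 2, 8, 4])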