optimize npu llama long context performance (#11478)
parent 913e750b01
commit ec3a912ab6
1 changed file with 9 additions and 8 deletions
@@ -34,6 +34,7 @@
 from typing import Optional, Tuple, List, Union
 
+import math
 import torch
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
@@ -230,14 +231,14 @@ def llama_attention_forward(
             is_causal=self.is_causal and causal_mask is None and q_len > 1,
         )
     else:
-        # second+ token
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            is_causal=self.is_causal and causal_mask is None and q_len > 1,
-        )
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if causal_mask is not None:
+            attn_weights = attn_weights + causal_mask
+
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
 
     attn_output = attn_output.transpose(1, 2).contiguous()
 
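Judging from the hunk above, the decode (second+ token) branch switches from a fused scaled_dot_product_attention call to an explicit matmul / softmax / matmul computation, with the new `import math` supplying the 1/sqrt(head_dim) scaling. The sketch below is not part of the commit; it is a minimal sanity check, assuming float32 tensors and PyTorch's default SDPA scaling, that the explicit path produces the same output as scaled_dot_product_attention on identical inputs. The shape values (batch, num_heads, q_len, kv_len, head_dim) are made-up illustration values.

import math
import torch

# Illustrative decode-step shapes: one new query token attending over a long KV cache.
batch, num_heads, q_len, kv_len, head_dim = 1, 8, 1, 512, 64
query_states = torch.randn(batch, num_heads, q_len, head_dim)
key_states = torch.randn(batch, num_heads, kv_len, head_dim)
value_states = torch.randn(batch, num_heads, kv_len, head_dim)
causal_mask = torch.zeros(batch, 1, q_len, kv_len)  # nothing masked at decode time

# Fused reference, as called by the removed lines.
ref = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask)

# Explicit path, mirroring the added lines (head_dim stands in for self.head_dim).
attn_weights = torch.matmul(query_states,
                            key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                           dtype=torch.float32).to(value_states.dtype)
out = torch.matmul(attn_weights, value_states)

assert torch.allclose(ref, out, atol=1e-5)
print("explicit decode attention matches scaled_dot_product_attention")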