optimize npu llama long context performance (#11478)

Yishuo Wang 2024-07-01 16:49:23 +08:00 committed by GitHub
parent 913e750b01
commit ec3a912ab6

@@ -34,6 +34,7 @@
 from typing import Optional, Tuple, List, Union
+import math
 import torch
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -230,14 +231,14 @@ def llama_attention_forward(
             is_causal=self.is_causal and causal_mask is None and q_len > 1,
         )
     else:
         # second+ token
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            is_causal=self.is_causal and causal_mask is None and q_len > 1,
-        )
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if causal_mask is not None:
+            attn_weights = attn_weights + causal_mask
+        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                   dtype=torch.float32).to(value_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
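
For reference, the unfused path introduced for second+ tokens computes the same result as the fused scaled_dot_product_attention call it replaces; the commit only swaps the fused kernel, which the title suggests was the long-context bottleneck on the NPU, for explicit matmul/softmax steps. Below is a minimal standalone sketch of that equivalence; the tensor shapes, the single-token decode setup, and the all-zero additive causal_mask are illustrative assumptions, not values taken from the patch.

# Sketch only: check that the unfused matmul/softmax attention used above
# matches torch.nn.functional.scaled_dot_product_attention numerically.
# Shapes are assumed for illustration (one decode step against a long KV cache).
import math
import torch

bsz, num_heads, q_len, kv_len, head_dim = 1, 8, 1, 2048, 64
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)
causal_mask = torch.zeros(bsz, 1, q_len, kv_len)  # additive mask; zeros = nothing masked

# Fused reference (the call removed from the second+ token path)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask)

# Unfused path from the patch: QK^T / sqrt(d) + mask, softmax in fp32, then weights @ V
attn_weights = torch.matmul(query_states,
                            key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = attn_weights + causal_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                           dtype=torch.float32).to(value_states.dtype)
manual_out = torch.matmul(attn_weights, value_states)

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # expected: True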