optimize npu llama long context performance (#11478)

Yishuo Wang 2024-07-01 16:49:23 +08:00 committed by GitHub
parent 913e750b01
commit ec3a912ab6

@@ -34,6 +34,7 @@
 from typing import Optional, Tuple, List, Union
+import math
 import torch
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -230,14 +231,14 @@ def llama_attention_forward(
                 is_causal=self.is_causal and causal_mask is None and q_len > 1,
             )
         else:
-            # second+ token
-            attn_output = torch.nn.functional.scaled_dot_product_attention(
-                query_states,
-                key_states,
-                value_states,
-                attn_mask=causal_mask,
-                is_causal=self.is_causal and causal_mask is None and q_len > 1,
-            )
+            attn_weights = torch.matmul(query_states,
+                                        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+            if causal_mask is not None:
+                attn_weights = attn_weights + causal_mask
+
+            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
+                                                       dtype=torch.float32).to(value_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose(1, 2).contiguous()
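
The hunk above replaces the second+ token (decode) call to torch.nn.functional.scaled_dot_product_attention with an explicit eager attention computation whose softmax runs in float32. Below is a minimal standalone sketch of that computation for reference; the tensor names mirror the patch, while the example dimensions (batch, heads, sequence lengths, head_dim) are illustrative assumptions, not values taken from this repository.

# Minimal sketch of the eager decode-path attention introduced above.
import math
import torch

def eager_decode_attention(query_states, key_states, value_states, causal_mask, head_dim):
    # (batch, heads, q_len, head_dim) x (batch, heads, head_dim, kv_len)
    #   -> (batch, heads, q_len, kv_len)
    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if causal_mask is not None:
        attn_weights = attn_weights + causal_mask
    # softmax in float32 for numerical stability, then cast back to the value dtype
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                               dtype=torch.float32).to(value_states.dtype)
    # (batch, heads, q_len, kv_len) x (batch, heads, kv_len, head_dim)
    #   -> (batch, heads, q_len, head_dim)
    return torch.matmul(attn_weights, value_states)

# Illustrative decode-step shapes: one query token attending to a 4096-token KV cache.
q = torch.randn(1, 32, 1, 128)
k = torch.randn(1, 32, 4096, 128)
v = torch.randn(1, 32, 4096, 128)
out = eager_decode_attention(q, k, v, None, 128)  # -> (1, 32, 1, 128)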