[LLM] Improve chatglm2/3 rest token performance with long context (#9716)

commit 522cf5ed82 (parent f2e6abb563)
2 changed files with 18 additions and 6 deletions
@@ -167,10 +167,17 @@ def baichuan_attention_forward_7b(
             query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask()
         )
     else:
-        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True,
-                                            enable_mem_efficient=True):
-            attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states,
-                                                         attn_mask=attention_mask)
+        if attention_mask is not None:
+            if attention_mask.dtype == torch.bool:
+                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+
+        scaling_factor = 1 / math.sqrt(query_states.size(-1))
+        attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1))
+        if attention_mask is not None:
+            attn_output += attention_mask
+        attn_output = torch.softmax(attn_output, -1)
+        attn_output = torch.matmul(attn_output, value_states)
+
     attn_output = attn_output.transpose(1, 2)
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
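Note: the hunk above swaps the F.scaled_dot_product_attention call (under torch.backends.cuda.sdp_kernel) for an explicit matmul -> softmax -> matmul attention in this branch. Below is a minimal, self-contained sketch of that manual path; the function name manual_attention, the shape comments, and the out-of-place mask conversion are illustrative assumptions, not code from this commit.

# Minimal sketch of the manual attention path shown in the hunk above.
# Names and shapes are illustrative, not taken from the commit.
import math
import torch

def manual_attention(query_states, key_states, value_states, attention_mask=None):
    # query/key/value: (bsz, num_heads, seq_len, head_dim); attention_mask is
    # either a boolean "keep" mask or an additive float mask broadcastable to
    # (bsz, num_heads, q_len, kv_len).
    if attention_mask is not None and attention_mask.dtype == torch.bool:
        # Convert a boolean mask into an additive one: masked positions get
        # -inf so they vanish after softmax.
        attention_mask = torch.zeros_like(
            attention_mask, dtype=query_states.dtype
        ).masked_fill(attention_mask.logical_not(), float("-inf"))
    scaling_factor = 1 / math.sqrt(query_states.size(-1))
    attn_weights = torch.matmul(query_states * scaling_factor,
                                key_states.transpose(-2, -1))
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, -1)
    return torch.matmul(attn_weights, value_states)

During decoding, the "rest" tokens have q_len of 1 while the key/value length grows with the context, which is presumably where this explicit path pays off per the commit title.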
@@ -17,6 +17,7 @@
 # https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
 #
 
+import math
 import torch
 from typing import Optional, Tuple, List
 import torch.nn.functional as F
@@ -370,9 +371,13 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
                                                                              value_layer,
                                                                              attention_mask,
                                                                              is_causal=True)
+        elif attention_mask is None:
+            scaling_factor = 1 / math.sqrt(query_layer.size(-1))
+            attn = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
+            attn = torch.softmax(attn, -1)
+            context_layer = torch.matmul(attn, value_layer)
         else:
-            if attention_mask is not None:
-                attention_mask = ~attention_mask
+            attention_mask = ~attention_mask
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
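Note: the new elif branch computes unmasked attention explicitly instead of going through scaled_dot_product_attention. As a quick sanity check (not part of the commit), the explicit computation should numerically agree with F.scaled_dot_product_attention when no mask is given. A decode-shaped example with a single query token against a long KV cache, using illustrative sizes:

import math
import torch
import torch.nn.functional as F

# Shapes mimic one decode ("rest") token attending over a long context;
# the sizes below are illustrative assumptions.
torch.manual_seed(0)
bsz, n_heads, q_len, kv_len, head_dim = 1, 32, 1, 4096, 128
query_layer = torch.randn(bsz, n_heads, q_len, head_dim)
key_layer = torch.randn(bsz, n_heads, kv_len, head_dim)
value_layer = torch.randn(bsz, n_heads, kv_len, head_dim)

# Explicit path, mirroring the new `elif attention_mask is None:` branch.
scaling_factor = 1 / math.sqrt(query_layer.size(-1))
attn = torch.matmul(query_layer * scaling_factor, key_layer.transpose(-2, -1))
attn = torch.softmax(attn, -1)
manual = torch.matmul(attn, value_layer)

# Reference path; SDPA's default scale is also 1/sqrt(head_dim).
reference = F.scaled_dot_product_attention(query_layer, key_layer, value_layer)

print(torch.allclose(manual, reference, atol=1e-4))  # expected: True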