LLM: fix first token judgement of flash attention (#9841)
* fix flash attention
* meet code review
* fix
parent ad4a6b5096
commit 16433dd959

3 changed files with 22 additions and 15 deletions

@@ -367,7 +367,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attention_mask):
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
-        if attention_mask is None and (use_flash_attention(query_layer) or
+        if attention_mask is None and (use_flash_attention(query_layer, key_layer) or
                                        L == S and query_layer.device.type == "cpu"):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
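For context, this call site gates torch.nn.functional.scaled_dot_product_attention on the helper, which now receives both query and key. Below is a minimal sketch of that gate; the surrounding permutes and dtype handling are omitted and the helper is passed in as a parameter, both simplifications for illustration rather than the exact call site:

    import torch
    import torch.nn.functional as F

    def sdpa_gate(query, key, value, attention_mask, use_flash_attention):
        # query/key/value: [batch_size, head_num, seq_len, head_dim]
        L, S = query.shape[2], key.shape[2]
        if attention_mask is None and (use_flash_attention(query, key) or
                                       (L == S and query.device.type == "cpu")):
            # both arms of the condition imply L == S (a prefill pass),
            # so a causal mask is the right mask here
            return F.scaled_dot_product_attention(query, key, value, is_causal=True)
        # otherwise fall back to the masked SDPA path
        return F.scaled_dot_product_attention(query, key, value,
                                              attn_mask=attention_mask)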

@@ -139,14 +139,6 @@ def llama_attention_forward_4_31(
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(hidden_states)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)

@@ -259,6 +251,15 @@ def llama_attention_forward_4_31(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                      dtype=attention_dtype)
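The fsdp_flag block is moved rather than deleted: hidden_states is [batch_size, seq_len, hidden_size] and carries no information about the KV cache, so the old placement could not tell a true prefill pass from a later pass that feeds new tokens on top of an existing cache. Computing the flag after the cache update, from the 4-D query_states and key_states, makes that distinction visible. A rough shape sketch follows; the concrete sizes are made up for illustration:

    import torch

    bsz, num_heads, head_dim = 1, 32, 128

    # prefill ("first token") pass: no KV cache yet, so q_len == k_len
    q_prefill = torch.empty(bsz, num_heads, 28, head_dim)
    k_prefill = torch.empty(bsz, num_heads, 28, head_dim)   # flash attention OK

    # decode pass: one new query token against the whole cache
    q_decode = torch.empty(bsz, num_heads, 1, head_dim)
    k_decode = torch.empty(bsz, num_heads, 29, head_dim)    # fall back to normal SDPA

    # the old check only saw hidden_states of shape [bsz, seq_len, hidden_size],
    # here (1, 28, 4096) and (1, 1, 4096), which says nothing about cache length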

@@ -123,18 +123,24 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
         (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
 
 
-def use_flash_attention(query):
-    if query.dim() == 3:
-        bsz, q_len, _ = query.size()
-    elif query.dim() == 4:
-        bsz, _, q_len, _ = query.size()
+def use_flash_attention(query, key):
+    # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
+    # key's shape is always [batch_size, head_num, k_len, head_dim]
+    invalidInputError(query.dim() == 4,
+                      "Here query input of use_flash_attention should be [batch_size, "
+                      "head_num, q_len, head_dim]")
+    invalidInputError(key.dim() == 4,
+                      "Here key input of use_flash_attention should be [batch_size, "
+                      "head_num, k_len, head_dim]")
+    bsz, _, q_len, _ = query.size()
+    k_len = key.size()[2]
     # check whether ipex flash attention can be used
     if bsz > 1:
         # only use flash attention for batch_size = 1 now
         # as flash attention doesn't support attn_mask in ipex 2.1,
         # so it will cause output error for padded batch input
         return False
-    if q_len == 1:
+    if q_len != k_len:
         # now only use flash attention for first token
         # as it seems have no performance benifit for rest token now
         return False
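Putting the pieces together, the helper now judges "first token" by comparing query length with key length instead of testing q_len == 1: the old test treated any multi-token input as a first-token pass even when a KV cache was already present, while q_len != k_len only rejects non-prefill passes. A minimal, self-contained sketch of that decision, using a plain assert in place of invalidInputError and stopping at the two checks shown in this hunk (the real helper may apply further checks after them):

    import torch

    def use_flash_attention_sketch(query: torch.Tensor, key: torch.Tensor) -> bool:
        # query: [batch_size, head_num, q_len, head_dim]
        # key:   [batch_size, head_num, k_len, head_dim]
        assert query.dim() == 4 and key.dim() == 4
        bsz, _, q_len, _ = query.size()
        k_len = key.size(2)
        if bsz > 1:
            # ipex 2.1 flash attention has no attn_mask support, so a padded
            # batch would produce wrong outputs
            return False
        if q_len != k_len:
            # only the prefill (first-token) pass, where no KV cache exists yet,
            # goes through flash attention
            return False
        return True

    q = torch.empty(1, 32, 16, 128, dtype=torch.float16)
    k_cached = torch.empty(1, 32, 17, 128, dtype=torch.float16)
    print(use_flash_attention_sketch(q, q))                   # prefill: True
    print(use_flash_attention_sketch(q[:, :, :1], k_cached))  # decode: False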