LLM: relax batch check of flash attention by double-checking attention mask (#10270)
* relax batch check
* fix
* fix style
parent 07f36fbfcc
commit 4b08bc1417

2 changed files with 20 additions and 9 deletions
@@ -349,7 +349,7 @@ def llama_attention_forward_4_31_quantized(
                                                      cos, sin, position_ids, "llama")

     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
@@ -629,7 +629,7 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None

     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
@@ -1068,7 +1068,7 @@ def llama_attention_forward_4_36(
             past_key_value.value_cache[self.layer_idx] = value_states

     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
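At each of these call sites, the attention_mask now forwarded into use_flash_attention is the 4D additive mask the llama forward already receives: shape [bsz, 1, q_len, k_len], 0.0 where a token may attend and a large negative value (typically the dtype minimum) where it may not. As a rough illustration only, a sketch of such a mask under that convention; build_additive_causal_mask and pad_lens are made-up names, not part of this commit:

import torch

def build_additive_causal_mask(bsz, q_len, dtype=torch.float16, pad_lens=None):
    # Hypothetical helper: 0.0 where attention is allowed, dtype-min where it is not.
    neg = torch.finfo(dtype).min
    # strictly upper-triangular entries are masked, so the diagonal stays 0 (causal)
    causal = torch.triu(torch.full((q_len, q_len), neg, dtype=dtype), diagonal=1)
    mask = causal.expand(bsz, 1, q_len, q_len).clone()   # [bsz, 1, q_len, q_len]
    if pad_lens is not None:
        for b, n_pad in enumerate(pad_lens):
            # left-padded key positions of sequence b are masked for every query
            mask[b, :, :, :n_pad] = neg
    return mask

Without padding, every sequence ends up with the same upper-triangular mask and position [0, 0] is 0, which is exactly the situation the relaxed check below allows through.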
@@ -236,7 +236,7 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
                              (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)


-def use_flash_attention(query, key):
+def use_flash_attention(query, key, attention_mask=None):
     # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
     # key's shape is always [batch_size, head_num, k_len, head_dim]
     invalidInputError(query.dim() == 4,
@@ -248,11 +248,6 @@ def use_flash_attention(query, key):
     bsz, _, q_len, _ = query.size()
     k_len = key.size()[2]
     # check whether ipex flash attention can be used
-    if bsz > 1:
-        # only use flash attention for batch_size = 1 now
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        return False
     if q_len != k_len:
         # now only use flash attention for first token
         # as it seems have no performance benifit for rest token now
@@ -271,6 +266,22 @@ def use_flash_attention(query, key):
     if query.dtype not in [torch.float32, torch.float16]:
         # only use flash attention for fp32/fp16 input
         return False
+    if bsz > 1:
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output error for padded batch input
+        if attention_mask is None:
+            return True
+        else:
+            # TODO: below logic may change for different model
+            # attention mask shape : [bsz, 1, q_len, k_len]
+            if attention_mask[0].squeeze()[0, 0].item() != 0:
+                # first batch contains padding
+                # otherwise we suppose it should be a upper triangular matrix
+                # at the same time, the diagonal is also 0
+                return False
+            elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
+                # check whether mask of every batch is the same
+                return False
     return True
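To make the new branch concrete, here is a small, self-contained sketch (not part of the commit) that evaluates the two mask conditions it tests, using arbitrary toy shapes:

import torch

bsz, q_len = 2, 4
neg = torch.finfo(torch.float16).min

# Uniform causal mask shared by both sequences: 0.0 on and below the diagonal.
causal = torch.triu(torch.full((q_len, q_len), neg, dtype=torch.float16), diagonal=1)
uniform_mask = causal.expand(bsz, 1, q_len, q_len).clone()   # [bsz, 1, q_len, q_len]

# Left padding in the first sequence is visible at position [0, 0] (first condition).
first_padded = uniform_mask.clone()
first_padded[0, :, :, 0] = neg

# Padding only in the second sequence escapes the first condition but not the second.
second_padded = uniform_mask.clone()
second_padded[1, :, :, 0] = neg

print(uniform_mask[0].squeeze()[0, 0].item() != 0)    # False -> no padding detected
print(first_padded[0].squeeze()[0, 0].item() != 0)    # True  -> fall back

print(uniform_mask.equal(uniform_mask[0].repeat(bsz, 1, 1, 1)))     # True  -> masks match
print(second_padded.equal(second_padded[0].repeat(bsz, 1, 1, 1)))   # False -> fall back

In other words, a batch whose sequences all share a plain causal mask can now keep flash attention enabled, while any sign of padding still falls back to the regular attention path, which is what the removed unconditional bsz > 1 rejection used to force for every batched input.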