LLM: relax batch check of flash attention by double-checking the attention mask (#10270)

* relax batch check
* fix
* fix style
parent 07f36fbfcc
commit 4b08bc1417

2 changed files with 20 additions and 9 deletions
@@ -349,7 +349,7 @@ def llama_attention_forward_4_31_quantized(
                                                             cos, sin, position_ids, "llama")

     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
@@ -629,7 +629,7 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None

     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
@@ -1068,7 +1068,7 @@ def llama_attention_forward_4_36(
                 past_key_value.value_cache[self.layer_idx] = value_states

     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
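The three hunks above make the same one-line change at each call site: the flash path is only considered at inference time, and the attention mask is now passed along so use_flash_attention can inspect it for batched input. Below is a minimal, self-contained sketch of that gating pattern; select_path and use_flash_attention_stub are illustrative names, not code from the repository.

import torch

def use_flash_attention_stub(query, key, attention_mask=None):
    # stand-in for the real use_flash_attention (edited in the second file);
    # the point here is only that the mask is forwarded to the helper
    return attention_mask is None

def select_path(training, hidden_states, query_states, key_states, attention_mask=None):
    # mirrors the call sites: flash attention is never considered while
    # training or when activations require grad
    if not training and not hidden_states.requires_grad:
        fsdp_flag = use_flash_attention_stub(query_states, key_states, attention_mask)
    else:
        fsdp_flag = False
    return "flash" if fsdp_flag else "regular"

hidden = torch.randn(1, 8, 4096)           # [batch, seq_len, hidden_size]
q = k = torch.randn(1, 32, 8, 128)         # [batch, heads, seq_len, head_dim]
print(select_path(False, hidden, q, k))                    # flash (inference, no mask)
print(select_path(True, hidden.requires_grad_(), q, k))    # regular (training)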
@@ -236,7 +236,7 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
         (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)


-def use_flash_attention(query, key):
+def use_flash_attention(query, key, attention_mask=None):
     # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
     # key's shape is always [batch_size, head_num, k_len, head_dim]
     invalidInputError(query.dim() == 4,
@@ -248,11 +248,6 @@ def use_flash_attention(query, key):
     bsz, _, q_len, _ = query.size()
     k_len = key.size()[2]
     # check whether ipex flash attention can be used
-    if bsz > 1:
-        # only use flash attention for batch_size = 1 now
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        return False
     if q_len != k_len:
         # now only use flash attention for first token
         # as it seems have no performance benifit for rest token now
@@ -271,6 +266,22 @@ def use_flash_attention(query, key):
     if query.dtype not in [torch.float32, torch.float16]:
         # only use flash attention for fp32/fp16 input
         return False
+    if bsz > 1:
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output error for padded batch input
+        if attention_mask is None:
+            return True
+        else:
+            # TODO: below logic may change for different model
+            # attention mask shape : [bsz, 1, q_len, k_len]
+            if attention_mask[0].squeeze()[0, 0].item() != 0:
+                # first batch contains padding
+                # otherwise we suppose it should be a upper triangular matrix
+                # at the same time, the diagonal is also 0
+                return False
+            elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
+                # check whether mask of every batch is the same
+                return False
     return True


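For a quick sanity check of the relaxed batch rule, the snippet below replays the bsz > 1 branch added above against hand-built additive masks of shape [bsz, 1, q_len, k_len]. batch_mask_allows_flash is only a hypothetical test harness that copies the diff's logic; it is not an API of the library.

import torch

def batch_mask_allows_flash(attention_mask, bsz):
    # reproduces the bsz > 1 check added to use_flash_attention above
    if attention_mask is None:
        return True
    if attention_mask[0].squeeze()[0, 0].item() != 0:
        # first sequence already masks its first key position -> padding
        return False
    elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
        # masks differ across the batch -> some sequence is padded
        return False
    return True

bsz, q_len = 2, 4
neg = torch.finfo(torch.float32).min

# causal mask shared by every sequence: positions above the diagonal are
# masked, the diagonal itself is 0, so no padding is present
causal = torch.triu(torch.full((q_len, q_len), neg), diagonal=1)
same_mask = causal.expand(bsz, 1, q_len, q_len).contiguous()
print(batch_mask_allows_flash(same_mask, bsz))   # True  -> flash attention allowed

# mask the first key position of the second sequence, as left padding would
padded = same_mask.clone()
padded[1, :, :, 0] = neg
print(batch_mask_allows_flash(padded, bsz))      # False -> fall back to regular attention

Because the ipex 2.1 flash kernel ignores attn_mask, a padded batch would otherwise produce wrong outputs, so the check only lets identical, padding-free masks through.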