LLM: fix first token judgement of flash attention (#9841)
* fix flash attention
* meet code review
* fix
This commit is contained in:

parent  ad4a6b5096
commit  16433dd959

3 changed files with 22 additions and 15 deletions
@@ -367,7 +367,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
-        if attention_mask is None and (use_flash_attention(query_layer) or
+        if attention_mask is None and (use_flash_attention(query_layer, key_layer) or
                                        L == S and query_layer.device.type == "cpu"):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
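For context, a minimal sketch (not part of the commit; shapes and names are illustrative) of why the key tensor is now passed alongside the query: the query length alone cannot reliably tell a prefill step from a decode step, while comparing the query length against the key length (which includes any KV cache) can.

import torch

# Hypothetical tensors in the [batch_size, head_num, seq_len, head_dim] layout
prefill_q = torch.randn(1, 32, 128, 128)   # first token: whole prompt
prefill_k = torch.randn(1, 32, 128, 128)   # no cache yet, so k_len == q_len

decode_q = torch.randn(1, 32, 1, 128)      # rest token: a single new position
decode_k = torch.randn(1, 32, 129, 128)    # cache + new token, so k_len != q_len

def is_first_token(query, key):
    # the comparison the patched use_flash_attention relies on
    return query.size(2) == key.size(2)

print(is_first_token(prefill_q, prefill_k))  # True  -> flash attention may be used
print(is_first_token(decode_q, decode_k))    # False -> fall back to the normal path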
@@ -139,14 +139,6 @@ def llama_attention_forward_4_31(
     device = hidden_states.device
     # for flash attention
     original_dtype = hidden_states.dtype
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(hidden_states)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
@@ -259,6 +251,15 @@ def llama_attention_forward_4_31(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
                                                                      dtype=attention_dtype)
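The relocated block above only works because it now runs after the KV cache has been folded into key_states, so key_states.size(2) reflects the full key length rather than just the new tokens. A rough sketch of that ordering (helper name and simplifications are mine, not from the repo):

import torch

def decide_flash_flag(query_states, key_states, past_key_value, training=False):
    # Illustrative only: fold the cached keys in first, then judge, mirroring
    # the patched ordering in llama_attention_forward_4_31.
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
    bsz, _, q_len, _ = query_states.size()
    k_len = key_states.size(2)
    if training or bsz > 1:
        return False           # training / padded batches: no flash attention
    return q_len == k_len      # True only for the first (prefill) token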
@@ -123,18 +123,24 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
         (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3)
 
 
-def use_flash_attention(query):
-    if query.dim() == 3:
-        bsz, q_len, _ = query.size()
-    elif query.dim() == 4:
-        bsz, _, q_len, _ = query.size()
+def use_flash_attention(query, key):
+    # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
+    # key's shape is always [batch_size, head_num, k_len, head_dim]
+    invalidInputError(query.dim() == 4,
+                      "Here query input of use_flash_attention should be [batch_size, "
+                      "head_num, q_len, head_dim]")
+    invalidInputError(key.dim() == 4,
+                      "Here key input of use_flash_attention should be [batch_size, "
+                      "head_num, k_len, head_dim]")
+    bsz, _, q_len, _ = query.size()
+    k_len = key.size()[2]
     # check whether ipex flash attention can be used
     if bsz > 1:
         # only use flash attention for batch_size = 1 now
         # as flash attention doesn't support attn_mask in ipex 2.1,
         # so it will cause output error for padded batch input
         return False
-    if q_len == 1:
+    if q_len != k_len:
         # now only use flash attention for first token
         # as it seems have no performance benifit for rest token now
         return False
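When the predicate passes, the callers cast to torch.float16 before calling scaled_dot_product_attention (per the commit's own comment, "use fp16 for flash attention"). A hedged sketch of that consumer side, with placeholder names and none of the rope/repeat_kv handling of the real forward:

import torch
import torch.nn.functional as F

def sdpa_with_optional_fp16(query_states, key_states, value_states,
                            fsdp_flag, original_dtype):
    # cast to fp16 only when the flash-attention flag is set; otherwise keep
    # the original model dtype
    attention_dtype = torch.float16 if fsdp_flag else original_dtype
    q = query_states.to(dtype=attention_dtype)
    k = key_states.to(dtype=attention_dtype)
    v = value_states.to(dtype=attention_dtype)
    # no attn_mask here: use_flash_attention already rejected padded batches
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.to(original_dtype)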