LLM: fix wrong batch output caused by flash attention (#9780)
* fix
* meet code review
* move batch size check to the beginning
* move qlen check inside function
* meet code review
parent 66e286a73d
commit 11d883301b

1 changed file with 15 additions and 6 deletions
@@ -137,10 +137,10 @@ def llama_attention_forward_4_31(
     # for flash attention
     original_dtype = hidden_states.dtype
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = check_flash_attention_available(hidden_states)
+        fsdp_flag = use_flash_attention(hidden_states)
     else:
         fsdp_flag = False
-    if fsdp_flag and q_len > 1:
+    if fsdp_flag:
         attention_dtype = torch.float16  # use fp16 for flash attention
     else:
         attention_dtype = original_dtype
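For context, the hunk above only decides which dtype the attention math will run in; the dispatch to flash attention itself happens in the second hunk below. A minimal standalone sketch of that selection, assuming made-up tensor shapes and a stubbed use_flash_attention (its real definition is added in the last hunk):

import torch

def use_flash_attention(query):
    # stub mirroring the checks added in the last hunk (assumption for this sketch)
    bsz, q_len, _ = query.size()
    return query.device.type == "xpu" and bsz == 1 and q_len > 1

hidden_states = torch.randn(1, 32, 4096)   # (batch, seq_len, hidden_size), made-up shape
original_dtype = hidden_states.dtype

training = False
if not training and not hidden_states.requires_grad:
    fsdp_flag = use_flash_attention(hidden_states)
else:
    fsdp_flag = False

# fp16 only when the flash path will actually run; otherwise keep the incoming dtype
attention_dtype = torch.float16 if fsdp_flag else original_dtype
print(fsdp_flag, attention_dtype)   # False, torch.float32 on a non-XPU machine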
@@ -261,8 +261,7 @@ def llama_attention_forward_4_31(
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)

-    if fsdp_flag and q_len > 1:
-        # now only use flash attention for first token
+    if fsdp_flag:
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
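A runnable CPU sketch (not from the repo) of the pattern this hunk gates: cast query/key/value to the attention dtype, call torch.nn.functional.scaled_dot_product_attention, then cast the result back to the caller's dtype. The shapes, bfloat16 (a CPU-friendly stand-in for the float16 used on XPU), and is_causal=True are assumptions; the real call's remaining arguments are not shown in this hunk:

import torch
import torch.nn.functional as F

original_dtype = torch.float32
attention_dtype = torch.bfloat16   # the patch uses torch.float16 on XPU; bfloat16 runs on CPU

# (bsz, num_heads, q_len, head_dim) -- made-up sizes for a single-sequence prefill
q = torch.randn(1, 32, 128, 64, dtype=original_dtype)
k = torch.randn(1, 32, 128, 64, dtype=original_dtype)
v = torch.randn(1, 32, 128, 64, dtype=original_dtype)

attn_output = F.scaled_dot_product_attention(q.to(attention_dtype),
                                              k.to(attention_dtype),
                                              v.to(attention_dtype),
                                              is_causal=True)   # assumed for the prefill case
attn_output = attn_output.to(original_dtype)   # hand the original dtype back to the caller
print(attn_output.dtype, attn_output.shape)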
@@ -325,7 +324,7 @@ def llama_attention_selective_batching_forward_4_31(
     original_dtype = hidden_states.dtype
     # TODO: consider this later - flash attention
     # if not self.training and not hidden_states.requires_grad:
-    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    #     fsdp_flag = use_flash_attention(hidden_states)
     # else:
     #     fsdp_flag = False
     # if fsdp_flag and q_len > 1:
@@ -506,8 +505,18 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def check_flash_attention_available(query):
+def use_flash_attention(query):
+    bsz, q_len, _ = query.size()
     # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 now
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output error for padded batch input
+        return False
+    if q_len == 1:
+        # now only use flash attention for first token
+        # as it seems have no performance benifit for rest token now
+        return False
     if query.device.type != "xpu":
         # ipex flash attention only support for xpu
         return False
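The bsz > 1 early return added above is the heart of the fix: per the new comments, flash attention in ipex 2.1 takes no attn_mask, so a padded batch attends to padding and produces wrong output. A runnable CPU sketch (not from the repo, float32 and made-up shapes) of that failure mode, comparing scaled_dot_product_attention with and without a padding mask:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, n_heads, seq_len, head_dim = 2, 4, 6, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# Sequence 0 is full length; sequence 1 has its last two positions as padding.
attn_mask = torch.ones(bsz, 1, seq_len, seq_len, dtype=torch.bool)
attn_mask[1, :, :, -2:] = False   # do not attend to the padded keys of sequence 1

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
without_mask = F.scaled_dot_product_attention(q, k, v)   # what a mask-less kernel computes

print((with_mask[0] - without_mask[0]).abs().max().item())   # ~0: unpadded sequence unaffected
print((with_mask[1] - without_mask[1]).abs().max().item())   # clearly nonzero: padding leaks into the softmax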