LLM: fix wrong batch output caused by flash attention (#9780)
* fix
* meet code review
* move batch size check to the beginning
* move qlen check inside function
* meet code review
parent 66e286a73d
commit 11d883301b
1 changed file with 15 additions and 6 deletions
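
For context on the bug this commit fixes: the comments added in the diff below note that flash attention in ipex 2.1 doesn't support attn_mask, so a padded batch (batch size > 1) produces wrong output. The following is a minimal, stock-PyTorch sketch of that effect, not part of the patch: it uses PyTorch's own F.scaled_dot_product_attention on CPU with made-up shapes, purely to show that dropping the padding mask changes the result for padded sequences.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, heads, q_len, head_dim = 2, 1, 4, 8          # illustrative shapes only
q = torch.randn(bsz, heads, q_len, head_dim)
k = torch.randn(bsz, heads, q_len, head_dim)
v = torch.randn(bsz, heads, q_len, head_dim)

# Suppose sequence 1 has only 2 real tokens and 2 padding positions:
# an additive mask hides the padded keys of that sequence.
mask = torch.zeros(bsz, 1, q_len, q_len)
mask[1, :, :, 2:] = float("-inf")

with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
without_mask = F.scaled_dot_product_attention(q, k, v)   # what a mask-less kernel computes

print(torch.allclose(with_mask[0], without_mask[0]))  # True: unpadded sequence is unaffected
print(torch.allclose(with_mask[1], without_mask[1]))  # False: padded sequence attends to padding

This is why the patch below only enables the flash path for batch size 1.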
				
			
@@ -137,10 +137,10 @@ def llama_attention_forward_4_31(
     # for flash attention
     original_dtype = hidden_states.dtype
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = check_flash_attention_available(hidden_states)
+        fsdp_flag = use_flash_attention(hidden_states)
     else:
         fsdp_flag = False
-    if fsdp_flag and q_len > 1:
+    if fsdp_flag:
         attention_dtype = torch.float16  # use fp16 for flash attention
     else:
         attention_dtype = original_dtype
@@ -261,8 +261,7 @@ def llama_attention_forward_4_31(
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)
 
-    if fsdp_flag and q_len > 1:
-        # now only use flash attention for first token
+    if fsdp_flag:
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
@@ -325,7 +324,7 @@ def llama_attention_selective_batching_forward_4_31(
     original_dtype = hidden_states.dtype
     # TODO: consider this later - flash attention
     # if not self.training and not hidden_states.requires_grad:
-    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    #     fsdp_flag = use_flash_attention(hidden_states)
     # else:
     #     fsdp_flag = False
     # if fsdp_flag and q_len > 1:
@@ -506,8 +505,18 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 
 
-def check_flash_attention_available(query):
+def use_flash_attention(query):
+    bsz, q_len, _ = query.size()
     # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 now
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output error for padded batch input
+        return False
+    if q_len == 1:
+        # now only use flash attention for first token
+        # as it seems have no performance benifit for rest token now
+        return False
     if query.device.type != "xpu":
         # ipex flash attention only support for xpu
         return False
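
Putting the hunks together: check_flash_attention_available is renamed to use_flash_attention, and the batch-size and q_len checks move from the call sites into the helper. Below is a consolidated sketch of the post-patch gating, limited to the checks visible in this diff; the trailing return True is an assumption standing in for whatever further checks the full helper performs.

import torch

def use_flash_attention(query):
    bsz, q_len, _ = query.size()
    # check whether ipex flash attention can be used
    if bsz > 1:
        # flash attention doesn't support attn_mask in ipex 2.1,
        # so padded batch input would produce wrong output
        return False
    if q_len == 1:
        # only the first (prompt) token benefits; single-token decode steps see no speedup
        return False
    if query.device.type != "xpu":
        # ipex flash attention is only supported on xpu
        return False
    # assumption: further checks from the full file are omitted in this sketch
    return True

# illustrative call: a [batch, seq, hidden] tensor on CPU is rejected (not on xpu)
print(use_flash_attention(torch.randn(1, 32, 4096)))  # False

With the q_len check living inside the helper, the call sites simplify from `if fsdp_flag and q_len > 1:` to `if fsdp_flag:`, which is exactly what the first two hunks change.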