LLM: add flash attention support for llama (#9518)
* add initial flash attention for llama
* accelerate fp32 first token by changing to fp16 in advance
* support fp32
parent bf579507c2
commit b63aae8a8e
1 changed file with 68 additions and 24 deletions
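In short, the change dispatches on dtype and sequence length: during inference prefill (q_len > 1) on a supported XPU device, query/key/value are cast to fp16 in advance so the fused SDPA kernel can be used, and the output is cast back to the caller's dtype; decode steps (q_len == 1), training, and unsupported devices keep the original dtype and the explicit softmax path. A minimal standalone sketch of that dispatch pattern follows; the function name attend is hypothetical, it assumes PyTorch >= 2.0 for F.scaled_dot_product_attention, and it is not the patched function itself.

import math
import torch
import torch.nn.functional as F

def attend(q, k, v, use_flash, attn_mask=None):
    # q, k, v: (bsz, num_heads, seq_len, head_dim); attn_mask: optional additive mask.
    # Hypothetical sketch of the commit's dispatch; assumes a device where fp16
    # matmuls are supported and fast (e.g. GPU/XPU) when use_flash is True.
    original_dtype = q.dtype
    if use_flash and q.size(2) > 1:
        # prefill: cast to fp16 in advance and hand off to the fused kernel
        out = F.scaled_dot_product_attention(q.half(), k.half(), v.half(),
                                             is_causal=True)
    else:
        # decode / fallback: explicit softmax attention in the original dtype
        scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
        if attn_mask is not None:
            scores = scores + attn_mask
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(probs, v)
    return out.to(original_dtype)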
				
			
@@ -106,6 +106,16 @@ def llama_attention_forward_4_31(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = check_flash_attention_available(hidden_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag and q_len > 1:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     if self.config.pretraining_tp > 1:
         key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
@@ -194,31 +204,23 @@ def llama_attention_forward_4_31(
 
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=hidden_states.dtype)
+                                                                     dtype=attention_dtype)
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=hidden_states.dtype)
+                                                                         dtype=attention_dtype)
 
-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-    attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
-
-    if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
+    if fsdp_flag and q_len > 1:
+        # now only use flash attention for first token
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    else:
+        # otherwise, use native attention
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -241,4 +243,46 @@ def llama_attention_forward_4_31(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def check_flash_attention_available(query):
+    # check whether ipex flash attention can be used
+    if query.device.type != "xpu":
+        # ipex flash attention only support for xpu
+        return False
+    ipex_version = get_ipex_version()
+    if ipex_version <= "2.0.110+xpu":
+        # ipex flash attention is supported from ipex 2.1
+        return False
+    if not torch.xpu.has_xetla():
+        # ipex flash attention is only supported for xetla
+        # may update this later
+        return False
+    return True
+
+
+def native_sdp(query, key, value, attention_mask,
+               bsz, q_len, kv_seq_len, head_dim, num_heads):
+    attn_weights = torch.matmul(query,
+                                key.transpose(2, 3)) / math.sqrt(head_dim)
+
+    attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
+    if attn_weights.size() != attn_weights_size:
+        invalidInputError(False,
+                          f"Attention weights should be of size {attn_weights_size}, "
+                          f"but is {attn_weights.size()}")
+
+    if attention_mask is not None:
+        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+        if attention_mask.size() != attn_mask_size:
+            invalidInputError(False,
+                              f"Attention mask should be of size {attn_mask_size}, "
+                              f"but is {attention_mask.size()}")
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                         dtype=torch.float32).to(value.dtype)
+    attn_output = torch.matmul(attn_weights, value)
+    return attn_output, attn_weights
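Note that decode steps (q_len == 1) and training deliberately stay on native_sdp; only the long prefill benefits from the fused kernel, which is why the code comments say flash attention is used only for the first token. As a sanity check that the two paths compute the same result, here is a small CPU-only sketch (illustrative shapes, not part of the patch) comparing PyTorch's fused kernel against the explicit softmax formulation used by native_sdp:

import math
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 16, 32
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# fused path: causal masking handled inside the kernel
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# explicit path: mirrors native_sdp, with an additive causal mask
causal_mask = torch.full((q_len, q_len), float("-inf")).triu(1)
weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
weights = weights + causal_mask
weights = F.softmax(weights, dim=-1, dtype=torch.float32).to(v.dtype)
manual = torch.matmul(weights, v)

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True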