Qwen2 fp16 sdp (#10427)
* qwen2 sdp and refine
* update
* update
* fix style
* remove use_flash_attention

parent 1315150e64
commit 24473e331a

3 changed files with 52 additions and 51 deletions
@@ -604,20 +604,19 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None
 
-    fsdp_flag = not self.training and not hidden_states.requires_grad and \
-        use_flash_attention(query_states, key_states, attention_mask)
-
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
         attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
                                                      key_states.to(device, dtype=torch.float16),
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
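
The hunk above drops the precomputed fsdp_flag and writes the inference-only check directly into the if, then runs scaled_dot_product_attention in fp16. Below is a minimal standalone sketch of that fp16 SDPA call; the tensor shapes, the device selection, and the fp32 CPU fallback are illustrative assumptions, not part of the diff:

# Minimal sketch of the fp16 SDPA pattern used in the branch above.
# Shapes, device selection and the CPU fallback are assumptions for illustration.
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 32, 1024, 128   # hypothetical prefill shapes
device = "xpu" if getattr(torch, "xpu", None) is not None and torch.xpu.is_available() else "cpu"
dtype = torch.float16 if device != "cpu" else torch.float32   # fp16 assumes an accelerator

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Cast q/k/v only for the attention kernel; is_causal=True applies the causal
# mask internally, so no explicit attention_mask tensor is passed here.
attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=dtype),
                                             key_states.to(device, dtype=dtype),
                                             value_states.to(device, dtype=dtype),
                                             is_causal=True)
print(attn_output.shape, attn_output.dtype)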
@@ -1249,29 +1248,20 @@ def llama_attention_forward_4_36_original(
                 past_key_value.key_cache[self.layer_idx] = key_states
                 past_key_value.value_cache[self.layer_idx] = value_states
 
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
-
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
         # now only use flash attention for first token
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
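
For context, repeat_kv (which qwen2.py imports from bigdl.llm.transformers.models.llama, see the import hunks below) expands grouped-query key/value heads so that k/v have as many heads as the queries before attention. A minimal sketch that mirrors the standard Hugging Face llama helper, reproduced here for reference rather than copied from this repository:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads,
                                                           n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(1, 8, 16, 64)    # 8 kv heads (hypothetical)
print(repeat_kv(kv, 4).shape)     # torch.Size([1, 32, 16, 64]) -> matches 32 query heads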
@@ -273,10 +273,8 @@ def qwen_attention_forward_original(
     if not decoding_fast_path:
         query = query.transpose(1, 2)
 
-    fsdp_flag = not self.training and not hidden_states.requires_grad and \
-        use_flash_attention(query, key)
-
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query, key):
         attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
                                                      key.to(device, dtype=torch.float16),
                                                      value.to(device, dtype=torch.float16),
@@ -284,7 +282,8 @@ def qwen_attention_forward_original(
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query,
                                                     key,
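
The two qwen hunks apply the same refactor as the llama ones: instead of precomputing fsdp_flag, the eligibility test is written straight into the if/elif, so use_flash_attention and use_esimd_sdp only run once the cheap inference-time checks have already passed. A toy illustration of that short-circuit behaviour, using a hypothetical stand-in predicate:

# Toy illustration of short-circuit gating; expensive_check is a hypothetical
# stand-in for use_flash_attention / use_esimd_sdp, not the real helpers.
def expensive_check(name: str) -> bool:
    print(f"evaluating {name}")
    return False

training, requires_grad = True, True   # e.g. a training step

# The left-hand operands are already False, so expensive_check is never evaluated.
if not training and not requires_grad and expensive_check("flash eligibility"):
    print("flash / SDP path")
elif not training and not requires_grad and expensive_check("esimd eligibility"):
    print("esimd path")
else:
    print("eager path")                # printed without calling expensive_check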
@@ -43,6 +43,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from bigdl.llm.transformers.models.llama import repeat_kv
 from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
@@ -51,6 +52,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.transformers.kv import DynamicFp8Cache
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
 
 
@@ -345,34 +347,44 @@ def qwen2_attention_forward_origin(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    if not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
-                      ("Attention weights should be of size "
-                       f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
-                       "but is {attn_weights.size()}"))
+        invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
+                          ("Attention weights should be of size "
+                           f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
+                           f" but is {attn_weights.size()}"))
 
-    if attention_mask is not None:
-        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
-                          (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
-                           f" but is {attention_mask.size()}"))
+        if attention_mask is not None:
+            invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+                              (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
+                               f" but is {attention_mask.size()}"))
 
-        attn_weights = attn_weights + attention_mask
+            attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = \
-        nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_weights = nn.functional.dropout(attn_weights,
-                                         p=self.attention_dropout,
-                                         training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = \
+            nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights,
+                                             p=self.attention_dropout,
+                                             training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
-                      "`attn_output` should be of size "
-                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
-                      f" but is {attn_output.size()}")
+        invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
+                          "`attn_output` should be of size "
+                          f"{(bsz, self.num_heads, q_len, self.head_dim)},"
+                          f" but is {attn_output.size()}")
 
-    attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
     attn_output = self.o_proj(attn_output)
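
After this hunk, qwen2_attention_forward_origin chooses between the linear_fp16_esimd SDP kernel and the eager softmax path; most of the "+" lines are the existing eager code re-indented under the new else branch. The outline below is a simplified, standalone sketch of that structure, not the real function; sdp_ok stands in for the inference-only use_esimd_sdp gate:

import math
import torch
import torch.nn as nn

def attention_core(query_states, key_states, value_states, attention_mask=None,
                   sdp_ok=False, dropout_p=0.0, training=False):
    # Simplified outline of the two paths in the hunk above (illustrative only).
    if sdp_ok:
        # Fast path in the real code: linear_fp16_esimd.sdp_forward(q, k, v),
        # followed by attn_output.view(query_states.shape); omitted here because
        # it requires the ESIMD extension.
        raise NotImplementedError("requires the linear_fp16_esimd extension")
    # Eager path: scores -> optional mask -> fp32 softmax -> dropout -> weighted values.
    head_dim = query_states.shape[-1]
    attn_weights = torch.matmul(query_states,
                                key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                         dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout_p, training=training)
    attn_output = torch.matmul(attn_weights, value_states)
    return attn_output.transpose(1, 2).contiguous(), attn_weights

q = k = v = torch.randn(1, 16, 8, 64)   # (bsz, heads, seq, head_dim), hypothetical
out, _ = attention_core(q, k, v)
print(out.shape)                        # torch.Size([1, 8, 16, 64])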
@@ -380,7 +392,7 @@ def qwen2_attention_forward_origin(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def qwen2_sdpa_attention_forward(
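
The changed return statement casts attn_output back to the dtype of hidden_states, since the fp16 SDP path can produce a half-precision tensor while the surrounding model may run in another precision. A short illustration of the same cast, with hypothetical dtypes:

import torch

hidden_states = torch.randn(1, 8, 4096, dtype=torch.bfloat16)   # assumed model dtype
attn_output = torch.randn(1, 8, 4096, dtype=torch.float16)      # e.g. fp16 SDP kernel output
attn_output = attn_output.to(hidden_states.dtype)               # what the new return does
assert attn_output.dtype == torch.bfloat16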