Qwen2 fp16 sdp (#10427)

* qwen2 sdp and refine

* update

* update

* fix style

* remove use_flash_attention
Xin Qiu authored 2024-03-15 13:12:03 +08:00, committed by GitHub
parent 1315150e64
commit 24473e331a
3 changed files with 52 additions and 51 deletions

@@ -604,20 +604,19 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None

-    fsdp_flag = not self.training and not hidden_states.requires_grad and \
-        use_flash_attention(query_states, key_states, attention_mask)
-
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
         attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
                                                      key_states.to(device, dtype=torch.float16),
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
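Both llama hunks follow the same pattern: the precomputed fsdp_flag is folded directly into the if, so use_flash_attention is only evaluated at inference time (no training, no gradients), and query/key/value are cast to float16 on the target device right at the F.scaled_dot_product_attention call. A minimal standalone sketch of that dispatch, with an illustrative training flag and a device check standing in for BigDL's use_flash_attention helper (not the actual API):

import torch
import torch.nn.functional as F

def sdp_fp16(query, key, value, training: bool) -> torch.Tensor:
    # Illustrative gate: mirror the "not self.training and not
    # hidden_states.requires_grad" check from the patch; the extra device
    # test stands in for use_flash_attention(), which only fires on XPU.
    if not training and not query.requires_grad and query.device.type != "cpu":
        out = F.scaled_dot_product_attention(query.to(torch.float16),
                                             key.to(torch.float16),
                                             value.to(torch.float16),
                                             is_causal=True)
        return out.to(query.dtype)
    # Fallback: run SDPA in the original dtype.
    return F.scaled_dot_product_attention(query, key, value, is_causal=True)

# Usage with (batch, num_heads, seq_len, head_dim) tensors.
q, k, v = (torch.randn(1, 32, 128, 64) for _ in range(3))
out = sdp_fp16(q, k, v, training=False)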
@@ -1249,29 +1248,20 @@ def llama_attention_forward_4_36_original(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states

-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
-
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)

-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
         # now only use flash attention for first token
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
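The 4.36 hunk additionally drops the attention_dtype selection, so repeat_kv no longer casts its result; the fp16 cast now happens only inside the SDP call. For readers unfamiliar with grouped-query attention, here is a sketch of what a repeat_kv-style helper does, modelled on the standard Hugging Face implementation (the BigDL version is imported from bigdl.llm.transformers.models.llama and may differ in detail):

import torch

def repeat_kv_sketch(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (batch, num_kv_heads, seq_len, head_dim) to
    # (batch, num_kv_heads * n_rep, seq_len, head_dim) so grouped key/value
    # heads line up with the query heads before attention.
    batch, num_kv_heads, seq_len, head_dim = hidden.shape
    if n_rep == 1:
        return hidden
    hidden = hidden[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# With 32 query heads and 8 KV heads, each KV head is repeated 4 times.
k = torch.randn(1, 8, 128, 64)
print(repeat_kv_sketch(k, 4).shape)  # torch.Size([1, 32, 128, 64])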

@@ -273,10 +273,8 @@ def qwen_attention_forward_original(
     if not decoding_fast_path:
         query = query.transpose(1, 2)

-    fsdp_flag = not self.training and not hidden_states.requires_grad and \
-        use_flash_attention(query, key)
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query, key):
         attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
                                                      key.to(device, dtype=torch.float16),
                                                      value.to(device, dtype=torch.float16),
@@ -284,7 +282,8 @@ def qwen_attention_forward_original(
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query,
                                                     key,
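The qwen hunks show the other branch of the new dispatch: when the flash path is not taken, use_esimd_sdp can route the call to Intel's linear_fp16_esimd.sdp_forward kernel. The sketch below wraps that dispatch in an import guard so it still runs where the ESIMD extension is unavailable; the wrapper and the PyTorch fallback are illustrative, only the sdp_forward(query, key, value) call shape and the view back to the query layout come from the patch:

import torch
import torch.nn.functional as F

def sdp_with_esimd_fallback(query, key, value):
    # linear_fp16_esimd ships with BigDL's Intel GPU stack; treat it as
    # optional so this sketch also runs without the extension installed.
    try:
        import linear_fp16_esimd
    except ImportError:
        linear_fp16_esimd = None
    if linear_fp16_esimd is not None:
        # Same call shape as in the patch, then reshape to the query layout.
        attn_output = linear_fp16_esimd.sdp_forward(query, key, value)
        return attn_output.view(query.shape)
    # Portable fallback: PyTorch's fused scaled dot-product attention.
    return F.scaled_dot_product_attention(query, key, value, is_causal=True)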

@@ -43,6 +43,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from bigdl.llm.transformers.models.llama import repeat_kv
 from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
@@ -51,6 +52,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.transformers.kv import DynamicFp8Cache
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
@@ -345,7 +347,17 @@ def qwen2_attention_forward_origin(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)

-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    if not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

     invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
                       ("Attention weights should be of size "
@@ -380,7 +392,7 @@ def qwen2_attention_forward_origin(
     if not output_attentions:
         attn_weights = None

-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value


 def qwen2_sdpa_attention_forward(
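Finally, the return statement casts the attention output back to hidden_states.dtype, since the fp16 SDP/ESIMD paths produce float16 while the surrounding model may run in another precision. A tiny illustration of that cast (shapes and dtypes picked arbitrarily for the example):

import torch

hidden_states = torch.randn(1, 128, 4096, dtype=torch.bfloat16)  # residual stream dtype
attn_output = torch.randn(1, 128, 4096, dtype=torch.float16)     # produced by the fp16 SDP path
attn_output = attn_output.to(hidden_states.dtype)                # cast back before returning
assert attn_output.dtype == hidden_states.dtype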