Qwen fp16 sdp (#10401)

* qwen sdp

* fix

* update

* update

* update sdp

* update

* fix style check

* add to origin type
Xin Qiu 2024-03-15 08:51:50 +08:00 committed by GitHub
parent 1c0f7ed3fa
commit cda38f85a9
2 changed files with 34 additions and 19 deletions


@@ -604,25 +604,17 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None

-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
+    fsdp_flag = not self.training and not hidden_states.requires_grad and \
+        use_flash_attention(query_states, key_states, attention_mask)

     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)

     if fsdp_flag:
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
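
For context, a minimal standalone sketch of the fp16 SDP pattern this hunk switches to, assuming torch >= 2.0. Only F.scaled_dot_product_attention(..., is_causal=True) and the "inference only" gate mirror the diff; the extra CPU check, the shapes, and the function name are illustrative assumptions standing in for BigDL's use_flash_attention heuristic.

import torch
import torch.nn.functional as F

def sdp_fp16(query, key, value, training=False):
    # Gate as in the diff: inference only (no training, no autograd). The CPU
    # check is an assumption so this sketch also runs on CPU-only machines,
    # where fp16 SDP kernels may be unavailable.
    use_fp16_sdp = (not training and not query.requires_grad
                    and query.device.type != "cpu")
    original_dtype = query.dtype
    if use_fp16_sdp:
        # Cast q/k/v to fp16, run fused SDP attention, cast the result back.
        out = F.scaled_dot_product_attention(query.to(torch.float16),
                                             key.to(torch.float16),
                                             value.to(torch.float16),
                                             is_causal=True)
        return out.to(original_dtype)
    # Fallback: same math in the original dtype.
    return F.scaled_dot_product_attention(query, key, value, is_causal=True)

# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 32, 128, 128)
k = torch.randn(1, 32, 128, 128)
v = torch.randn(1, 32, 128, 128)
print(sdp_fp16(q, k, v).shape)  # torch.Size([1, 32, 128, 128])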


@@ -42,6 +42,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
 from bigdl.llm.transformers.models.utils import rotate_half, SILU
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -131,6 +132,8 @@ def qwen_attention_forward_original(
                           "flash attn and kv_cache quantization are not supported")
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
@@ -270,6 +273,26 @@ def qwen_attention_forward_original(
     if not decoding_fast_path:
         query = query.transpose(1, 2)

-    attn_output, attn_weight = self._attn(
-        query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
-    )
+    fsdp_flag = not self.training and not hidden_states.requires_grad and \
+        use_flash_attention(query, key)
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
+                                                     key.to(device, dtype=torch.float16),
+                                                     value.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_output = attn_output.view(query.shape)
+        attn_output = attn_output.transpose(1, 2)
+        attn_weights = None
+    elif use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query,
+                                                    key,
+                                                    value)
+        attn_output = attn_output.view(query.shape)
+        attn_output = attn_output.transpose(1, 2)
+        attn_weight = None
+    else:
+        attn_output, attn_weight = self._attn(
+            query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
+        )
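
The hunk above is essentially a three-way dispatch: fp16 SDP when use_flash_attention allows it, an ESIMD fp16 kernel when use_esimd_sdp allows it, otherwise the module's own _attn. Below is a rough standalone sketch of that shape; the gating flags and the default_attn callable are assumptions, linear_fp16_esimd is only importable with the matching Intel runtime, and only the calls shown in the diff (F.scaled_dot_product_attention(..., is_causal=True) and linear_fp16_esimd.sdp_forward(query, key, value)) are taken from it.

import torch
import torch.nn.functional as F

def dispatch_attention(query, key, value, default_attn,
                       can_use_flash=False, can_use_esimd=False):
    # Branch 1: fused SDP attention in fp16 (mirrors the fsdp_flag branch).
    if can_use_flash:
        attn_output = F.scaled_dot_product_attention(query.to(torch.float16),
                                                     key.to(torch.float16),
                                                     value.to(torch.float16),
                                                     is_causal=True)
        return attn_output, None
    # Branch 2: Intel ESIMD fp16 kernel (only with the matching XPU runtime).
    if can_use_esimd:
        import linear_fp16_esimd  # assumption: optional Intel-specific module
        return linear_fp16_esimd.sdp_forward(query, key, value), None
    # Branch 3: whatever attention implementation the model already has.
    return default_attn(query, key, value)

# Usage with a plain softmax-attention fallback (illustrative only).
def naive_attn(q, k, v):
    w = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    return w @ v, w

q = k = v = torch.randn(1, 2, 8, 16)
out, _ = dispatch_attention(q, k, v, naive_attn)
print(out.shape)  # torch.Size([1, 2, 8, 16])
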
@@ -278,7 +301,7 @@ def qwen_attention_forward_original(
         attn_output, self.num_heads, self.head_dim
     )
-    attn_output = self.c_proj(context_layer)
+    attn_output = self.c_proj(context_layer).to(original_dtype)

     if use_cache:
         outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
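
Taken together, the original_dtype = hidden_states.dtype line added near the top of qwen_attention_forward_original and the .to(original_dtype) added after c_proj in the final hunk form a save/restore pair: downstream layers keep seeing the dtype they passed in even when the attention math ran in fp16. A tiny sketch of just that pattern, with a trivial stand-in for the fp16 compute:

import torch

def attention_block(hidden_states):
    # Remember the caller's dtype, as the first qwen hunk does.
    original_dtype = hidden_states.dtype
    # Stand-in for the fp16 attention + c_proj computation in the real code.
    out = hidden_states.to(torch.float16) * 1.0
    # Cast back after the output projection, as the final hunk does.
    return out.to(original_dtype)

x = torch.randn(1, 16, 64)  # float32 activations
assert attention_block(x).dtype == x.dtype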