Qwen fp16 sdp (#10401)
* qwen sdp
* fix
* update
* update
* update sdp
* update
* fix style check
* add to origin type
This commit is contained in:
parent 1c0f7ed3fa
commit cda38f85a9

2 changed files with 34 additions and 19 deletions
@@ -604,25 +604,17 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None

-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
+    fsdp_flag = not self.training and not hidden_states.requires_grad and \
+        use_flash_attention(query_states, key_states, attention_mask)

     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)

     if fsdp_flag:
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
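Outside the diff, the new gate plus fp16 SDP call in the LLaMA path boils down to roughly the sketch below. Names, shapes and the CPU fallback are illustrative only; use_flash_attention and repeat_kv in the real code are BigDL helpers that are not reproduced here, and the real patch always casts to fp16 on the XPU device.

import torch
import torch.nn.functional as F


def sdp_attention_fp16(query, key, value, training=False):
    # Single boolean gate, mirroring the rewritten fsdp_flag expression:
    # use the fused kernel only for inference-style tensors (no grad, not training).
    use_sdp = not training and not query.requires_grad

    if use_sdp:
        original_dtype = query.dtype
        # The patch casts to fp16 unconditionally on the XPU; this sketch only
        # does so off-CPU so it stays runnable anywhere.
        compute_dtype = torch.float16 if query.device.type != "cpu" else original_dtype
        attn_output = F.scaled_dot_product_attention(
            query.to(dtype=compute_dtype),
            key.to(dtype=compute_dtype),
            value.to(dtype=compute_dtype),
            is_causal=True)
        return attn_output.to(original_dtype)

    # Reference fallback: explicit causal mask plus softmax.
    scale = query.shape[-1] ** -0.5
    scores = (query @ key.transpose(-2, -1)) * scale
    causal = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool,
                                   device=scores.device), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value


if __name__ == "__main__":
    q = torch.randn(1, 32, 16, 64)   # (batch, heads, seq, head_dim)
    k = torch.randn(1, 32, 16, 64)
    v = torch.randn(1, 32, 16, 64)
    print(sdp_attention_fp16(q, k, v).shape)   # torch.Size([1, 32, 16, 64])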
@@ -42,6 +42,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
 from bigdl.llm.transformers.models.utils import rotate_half, SILU
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -131,6 +132,8 @@ def qwen_attention_forward_original(
                           "flash attn and kv_cache quantization are not supported")
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
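The two added lines record the dtype of the incoming hidden states so the attention output, which the new path computes in fp16, can be handed back in that dtype at the output projection (see the last hunk below). A tiny, self-contained illustration of that round trip, using plain tensors rather than anything Qwen-specific:

import torch

hidden_states = torch.randn(1, 8, 128)          # activations arrive in, say, fp32
original_dtype = hidden_states.dtype             # remembered up front, as in the hunk above

attn_output = hidden_states.to(torch.float16)    # stand-in for the fp16 SDP result
# ... attention math would happen in fp16 here ...
attn_output = attn_output.to(original_dtype)     # handed back in the dtype callers expect

assert attn_output.dtype == original_dtype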
@@ -270,6 +273,26 @@ def qwen_attention_forward_original(
     if not decoding_fast_path:
         query = query.transpose(1, 2)

+    fsdp_flag = not self.training and not hidden_states.requires_grad and \
+        use_flash_attention(query, key)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
+                                                     key.to(device, dtype=torch.float16),
+                                                     value.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_output = attn_output.view(query.shape)
+        attn_output = attn_output.transpose(1, 2)
+        attn_weights = None
+    elif use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query,
+                                                    key,
+                                                    value)
+        attn_output = attn_output.view(query.shape)
+        attn_output = attn_output.transpose(1, 2)
+        attn_weight = None
+    else:
+        attn_output, attn_weight = self._attn(
+            query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
+        )
-    attn_output, attn_weight = self._attn(
-        query, key, value, causal_mask, attention_mask, head_mask
-    )
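For orientation, the layout handling in the new Qwen branch looks roughly like the following: query, key and value enter as (batch, heads, seq, head_dim), F.scaled_dot_product_attention keeps that layout, the view is a same-shape reshape, and transpose(1, 2) returns to (batch, seq, heads, head_dim) so the existing head-merging code can flatten the heads. The final reshape here is an illustrative stand-in for _merge_heads, not Qwen's implementation.

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 32, 16, 128

# (batch, heads, seq, head_dim): the layout the new branch receives.
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
attn_output = attn_output.view(query.shape)    # same shape, mirroring the patch
attn_output = attn_output.transpose(1, 2)      # -> (batch, seq, heads, head_dim)

# Illustrative stand-in for _merge_heads: flatten heads back into the hidden size.
context_layer = attn_output.reshape(bsz, q_len, num_heads * head_dim)
print(context_layer.shape)   # torch.Size([1, 16, 4096])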
@@ -278,7 +301,7 @@ def qwen_attention_forward_original(
         attn_output, self.num_heads, self.head_dim
     )

-    attn_output = self.c_proj(context_layer)
+    attn_output = self.c_proj(context_layer).to(original_dtype)

     if use_cache:
         outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
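The query.to(key.dtype) in the fallback branch above appears to guard against query and key ending up in different precisions (for example, a low-precision cached key next to a full-precision query): PyTorch matmul does not promote across float dtypes, so the mismatch would raise. A minimal illustration of the failure mode and the cast, unrelated to the actual Qwen tensors:

import torch

query = torch.randn(1, 16, 128, dtype=torch.bfloat16)   # e.g. left in low precision upstream
key = torch.randn(1, 16, 128)                            # e.g. still fp32

try:
    scores = query @ key.transpose(-2, -1)               # mixed dtypes: matmul raises
except RuntimeError as err:
    print("dtype mismatch:", err)

scores = query.to(key.dtype) @ key.transpose(-2, -1)     # cast first, as the patch does
print(scores.dtype)                                      # torch.float32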