use sdp fp8 causal kernel (#11023)

This commit is contained in:
Yishuo Wang 2024-05-15 10:22:35 +08:00 committed by GitHub
parent c34f85e7d0
commit fad1dbaf60
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -137,38 +137,39 @@ def attention_forward(
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
if (isinstance(past_key_value, DynamicFp8Cache) and
use_sdp_fp8(q_len, kv_seq_len, query_states)):
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import linear_q4_0
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
elif (isinstance(past_key_value, DynamicNormalCache) and
use_sdp(q_len, kv_seq_len, self.head_dim, query_states)):
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
import linear_q4_0
attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
if isinstance(past_key_value, DynamicFp8Cache):
attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
else:
attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
else:
if isinstance(past_key_value, DynamicFp8Cache):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
if use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
import linear_q4_0
attn_output = linear_q4_0.sdp_causal(query_states, key_states, value_states)
else:
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = torch.matmul(query_states,
key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(value_states.dtype)
attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)