use new fp16 sdp in llama and mistral (#10734)
This commit is contained in:
parent
019293e1b9
commit
8086554d33
2 changed files with 42 additions and 36 deletions
|
|
@ -647,12 +647,11 @@ def llama_attention_forward_4_31_original(
|
|||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
if not self.training and not hidden_states.requires_grad and \
|
||||
use_flash_attention(query_states, key_states, attention_mask):
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if not self.training and not hidden_states.requires_grad and \
|
||||
use_flash_attention(query_states, key_states, attention_mask):
|
||||
attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
|
||||
key_states.to(device, dtype=torch.float16),
|
||||
value_states.to(device, dtype=torch.float16),
|
||||
|
|
@ -660,13 +659,14 @@ def llama_attention_forward_4_31_original(
|
|||
attn_weights = None
|
||||
elif not self.training and not hidden_states.requires_grad and \
|
||||
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
|
||||
import linear_fp16_esimd
|
||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||
key_states,
|
||||
value_states)
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
|
||||
attn_output = attn_output.view(query_states.shape)
|
||||
attn_weights = None
|
||||
else:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
# otherwise, use native attention
|
||||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
|
|
@ -1305,12 +1305,11 @@ def llama_attention_forward_4_36_original(
|
|||
past_key_value.key_cache[self.layer_idx] = key_states
|
||||
past_key_value.value_cache[self.layer_idx] = value_states
|
||||
|
||||
if not self.training and not hidden_states.requires_grad and \
|
||||
use_flash_attention(query_states, key_states, attention_mask):
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if not self.training and not hidden_states.requires_grad and \
|
||||
use_flash_attention(query_states, key_states, attention_mask):
|
||||
# now only use flash attention for first token
|
||||
attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
|
||||
key_states.to(device, dtype=torch.float16),
|
||||
|
|
@ -1319,13 +1318,14 @@ def llama_attention_forward_4_36_original(
|
|||
attn_weights = None
|
||||
elif not self.training and not hidden_states.requires_grad and \
|
||||
use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||
import linear_fp16_esimd
|
||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||
key_states,
|
||||
value_states)
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
|
||||
attn_output = attn_output.view(query_states.shape)
|
||||
attn_weights = None
|
||||
else:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
# otherwise, use native attention
|
||||
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
|
||||
attention_mask,
|
||||
|
|
|
|||
|
|
@ -495,13 +495,12 @@ def mistral_attention_forward_original(
|
|||
else:
|
||||
attention_dtype = original_dtype
|
||||
|
||||
if fsdp_flag:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
|
||||
if fsdp_flag:
|
||||
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
|
||||
key_states,
|
||||
value_states,
|
||||
|
|
@ -510,15 +509,19 @@ def mistral_attention_forward_original(
|
|||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||
import linear_fp16_esimd
|
||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||
key_states,
|
||||
value_states)
|
||||
# new fp16 sdp doesn't require repeat_kv
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
|
||||
attn_output = attn_output.view(query_states.shape)
|
||||
attn_weights = None
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
else:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
|
|
@ -885,13 +888,12 @@ def mistral_attention_forward_4_36_original(
|
|||
else:
|
||||
attention_dtype = original_dtype
|
||||
|
||||
if fsdp_flag:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
|
||||
if fsdp_flag:
|
||||
attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
|
||||
key_states,
|
||||
value_states,
|
||||
|
|
@ -900,15 +902,19 @@ def mistral_attention_forward_4_36_original(
|
|||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||
import linear_fp16_esimd
|
||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||
key_states,
|
||||
value_states)
|
||||
# new fp16 sdp doesn't require repeat_kv
|
||||
import linear_q4_0
|
||||
attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
|
||||
attn_output = attn_output.view(query_states.shape)
|
||||
attn_weights = None
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
else:
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
|
||||
dtype=attention_dtype)
|
||||
attn_output, attn_weights = compute_attn_outputs_weights(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
|
|
|
|||
Loading…
Reference in a new issue