hotfix native_sdp ut (#11046)

* hotfix native_sdp

* update
SONG Ge authored 2024-05-16 17:15:37 +08:00, committed by GitHub
parent 6be70283b7
commit 16b2a418be

@@ -855,11 +855,13 @@ def llama_attention_selective_batching_forward_4_31(
             current_key_states = repeat_kv(current_key_states, self.num_key_value_groups)
             current_value_states = repeat_kv(current_value_states, self.num_key_value_groups)
+            cache_position = None
             current_query_states = query_states[batch: batch + 1, :, :, :]
             attn_output, attn_weights = native_sdp(current_query_states,
                                                    current_key_states,
                                                    current_value_states,
                                                    attention_mask[batch],
+                                                   cache_position,
                                                    1,
                                                    1,
                                                    current_kv_len,
@@ -901,10 +903,12 @@ def llama_attention_selective_batching_forward_4_31(
     if isinstance(attention_mask, list):
         # For decoding fast path
         attention_mask = attention_mask[0]
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states,
                                            key_states,
                                            value_states,
                                            attention_mask,
+                                           cache_position,
                                            bsz,
                                            q_len,
                                            kv_seq_len,
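
Both hunks above apply the same fix: these call sites do not track a transformers-4.38 style cache_position, so they now define cache_position = None and pass it through to native_sdp. A minimal sketch of the resulting calling convention is below; the import path and the tensor shapes are assumptions for illustration, not taken from this diff.

# Sketch only (not the repository's code). The module path below is an assumption
# about where the patched native_sdp lives; shapes are illustrative.
import torch
from ipex_llm.transformers.models.llama import native_sdp  # assumed import path

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 32, 1, 16, 128
query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
attention_mask = torch.zeros(bsz, 1, q_len, kv_seq_len)  # additive mask, everything visible

cache_position = None  # no 4.38-style cache positions in this decoding fast path
attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                       attention_mask, cache_position,
                                       bsz, q_len, kv_seq_len,
                                       head_dim, num_heads, False)  # output_attentions=False
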
@@ -1445,7 +1449,7 @@ def llama_attention_forward_4_38_original(
 def native_sdp(query, key, value, attention_mask, cache_position,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
     if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
-        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
         attn_weights = torch.matmul(query.to(key.dtype),
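
For context on the parameter being threaded through: the non-split branch of native_sdp is ordinary scaled dot-product attention with an additive mask, and cache_position only affects how that mask is sliced. The following is a simplified reference sketch of that math, not the file's exact code.

# Reference sketch of the math behind the non-split branch (assumed, simplified).
import math
import torch

def reference_sdp(query, key, value, attention_mask, cache_position, kv_seq_len, head_dim):
    # raw attention scores: QK^T / sqrt(d)
    attn_weights = torch.matmul(query.to(key.dtype),
                                key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        if cache_position is not None:
            # transformers >= 4.38 prepares a full causal mask; slice it by position
            attn_weights = attn_weights + attention_mask[:, :, cache_position, :kv_seq_len]
        else:
            # legacy path: mask already has shape (bsz, 1, q_len, kv_seq_len)
            attn_weights = attn_weights + attention_mask
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
                                               dtype=torch.float32).to(value.dtype)
    return torch.matmul(attn_weights, value), attn_weights
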
@@ -1502,14 +1506,14 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask, cache_positio
             if cache_position is not None:
                 # for transformers 4.38.0
                 causal_mask = attention_mask[:, :, cache_position, : kv_seq_len]
-                attn_weights = attn_weights + causal_mask
+                attn_weights_split = attn_weights_split + causal_mask
             else:
                 attn_mask_size = (bsz, 1, q_len, kv_seq_len)
                 if attention_mask.size() != attn_mask_size:
                     invalidInputError(False,
                                       f"Attention mask should be of size {attn_mask_size}, "
                                       f"but is {attention_mask.size()}")
-                attn_weights = attn_weights + attention_mask
+                attn_weights_split = attn_weights_split + attention_mask
         attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
         attn_outputs.append(torch.matmul(attn_weights_split, v))
     attn_output = torch.cat(attn_outputs, dim=1)
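
The two replacements in this hunk fix a variable mix-up: inside the split loop the scores live in attn_weights_split, so the causal mask (or the legacy attention mask) must be added to that tensor, not to attn_weights. A self-contained sketch of the split-QKV pattern follows; the helper name and split count are illustrative, not the library's internals.

# Illustrative sketch of split-QKV scaled dot-product attention (not the file's code).
import math
import torch
import torch.nn as nn

def split_qkv_sdp_sketch(query, key, value, attention_mask, cache_position,
                         kv_seq_len, head_dim, num_splits=4):
    attn_outputs = []
    # split along the head dimension to bound the peak memory of the score matrix
    for q, k, v in zip(query.chunk(num_splits, dim=1),
                       key.chunk(num_splits, dim=1),
                       value.chunk(num_splits, dim=1)):
        attn_weights_split = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
        if attention_mask is not None:
            if cache_position is not None:
                # transformers 4.38 style: slice the prepared causal mask by position
                causal_mask = attention_mask[:, :, cache_position, :kv_seq_len]
                attn_weights_split = attn_weights_split + causal_mask
            else:
                attn_weights_split = attn_weights_split + attention_mask
        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
        attn_outputs.append(torch.matmul(attn_weights_split, v))
    return torch.cat(attn_outputs, dim=1)  # reassemble the head dimension
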
@@ -1767,8 +1771,9 @@ def llama_attention_fast_forward(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
+    cache_position = None
     attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                           attention_mask,
+                                           attention_mask, cache_position,
                                            bsz, q_len, kv_seq_len,
                                            self.head_dim, self.num_heads, output_attentions)
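
Since the commit title refers to the native_sdp unit test, here is a hypothetical example of the kind of check such a test can perform once cache_position is threaded through correctly. It reuses the reference_sdp sketch above and is not the repository's actual test.

# Hypothetical sanity check, not the repository's unit test: with cache_position=None
# and an all-zero additive mask, a native SDP implementation should agree with
# torch.nn.functional.scaled_dot_product_attention.
import torch
import torch.nn.functional as F

def test_sdp_reference_matches_torch():
    torch.manual_seed(0)
    bsz, num_heads, q_len, kv_seq_len, head_dim = 2, 4, 8, 8, 16
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
    v = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
    mask = torch.zeros(bsz, 1, q_len, kv_seq_len)  # additive float mask, nothing masked

    expected = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    actual, _ = reference_sdp(q, k, v, mask, None, kv_seq_len, head_dim)  # sketch above
    assert torch.allclose(actual, expected, atol=1e-5)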