Fix wrong output for Llama models on CPU (#10742)

Author: Guancheng Fu
Date:   2024-04-18 11:07:27 +08:00 (committed via GitHub)
Parent: e764f9b1b1
Commit: 31ea2f9a9f


@@ -1335,10 +1335,22 @@ def llama_attention_forward_4_36_original(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
         # otherwise, use native attention
-        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
-                                               attention_mask,
-                                               bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads, output_attentions)
+        if not output_attentions:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                dropout_p=self.attention_dropout if self.training else 0.0,
+                # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that
+                # does not create a causal mask in case q_len == 1.
+                is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            )
+        else:
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   attention_mask,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)

     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
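
For context, below is a minimal standalone sketch (not part of the patch) of the two paths the new branch selects between: the fused torch.nn.functional.scaled_dot_product_attention call used when attention weights are not requested, and an explicit softmax(QK^T / sqrt(d))V fallback. The naive_sdp helper and all tensor shapes here are illustrative stand-ins, not the repo's actual native_sdp implementation.

# Minimal sketch of the two attention paths the patch selects between.
# `naive_sdp` is a simplified stand-in for the repo's native_sdp helper.
import math
import torch

def naive_sdp(query, key, value, attention_mask, head_dim):
    # Explicit softmax(QK^T / sqrt(d))V that also returns the attention
    # weights, for callers that need output_attentions=True.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value), attn_weights

bsz, num_heads, q_len, head_dim = 1, 4, 8, 16
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# Fast path: fused SDPA kernel, no attention weights returned.
fused = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True)

# Fallback path: explicit computation with an additive causal mask.
causal_mask = torch.full((q_len, q_len), float("-inf")).triu(1)
manual, attn_weights = naive_sdp(query, key, value, causal_mask, head_dim)

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True

On this reading of the diff, the fused SDPA kernel now handles the common output_attentions=False case on CPU, while native_sdp remains only for callers that actually need the attention weights.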