qwen2 sdpa small fix (#11261)

Yishuo Wang 2024-06-07 14:42:18 +08:00 committed by GitHub
parent ea0d03fd28
commit 2623944604

@@ -337,6 +337,9 @@ def qwen2_attention_forward(
             is_causal=self.is_causal and attention_mask is None and q_len > 1)
     elif not self.training and not hidden_states.requires_grad and \
             use_flash_attention(query_states, key_states, attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = sdpa(query_states.to(device, dtype=torch.float16),
                            key_states.to(device, dtype=torch.float16),
                            value_states.to(device, dtype=torch.float16),
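
For context, the fix expands the key/value tensors to match the number of query heads before the fused SDPA call, because Qwen2 uses grouped-query attention (fewer K/V heads than query heads). Below is a minimal, self-contained sketch of that pattern. It assumes PyTorch >= 2.0 and uses the standard torch.nn.functional.scaled_dot_product_attention as a stand-in for IPEX-LLM's fused `sdpa` kernel; the tensor shapes are illustrative assumptions, not taken from this commit.

```python
# Sketch of the grouped-query-attention fix: repeat K/V heads so the SDPA
# kernel sees matching head counts for Q, K and V. Shapes and the use of
# F.scaled_dot_product_attention (instead of IPEX-LLM's fused sdpa) are
# illustrative assumptions, not part of this commit.
import torch
import torch.nn.functional as F


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand [batch, n_kv_heads, seq, head_dim] to [batch, n_kv_heads * n_rep, seq, head_dim]."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Illustrative Qwen2-like GQA shapes (assumed, not from the commit).
batch, q_len, head_dim = 1, 32, 128
num_heads, num_kv_heads = 28, 4
num_key_value_groups = num_heads // num_kv_heads

query_states = torch.randn(batch, num_heads, q_len, head_dim)
key_states = torch.randn(batch, num_kv_heads, q_len, head_dim)
value_states = torch.randn(batch, num_kv_heads, q_len, head_dim)

# The fix: repeat k/v heads if n_kv_heads < n_heads before calling SDPA.
key_states = repeat_kv(key_states, num_key_value_groups)
value_states = repeat_kv(value_states, num_key_value_groups)

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, is_causal=True)
print(attn_output.shape)  # torch.Size([1, 28, 32, 128])
```

Without the repeat, the K/V head dimension would not match the query head dimension expected by the fused kernel, which is what the "small fix" in this commit addresses.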