qwen2 sdpa small fix (#11261)
This commit is contained in:
parent ea0d03fd28
commit 2623944604
1 changed file with 3 additions and 0 deletions
@@ -337,6 +337,9 @@ def qwen2_attention_forward(
                                is_causal=self.is_causal and attention_mask is None and q_len > 1)
     elif not self.training and not hidden_states.requires_grad and \
             use_flash_attention(query_states, key_states, attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = sdpa(query_states.to(device, dtype=torch.float16),
                            key_states.to(device, dtype=torch.float16),
                            value_states.to(device, dtype=torch.float16),
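The three added lines expand the key/value heads before the fused sdpa call. Qwen2 uses grouped-query attention (fewer KV heads than query heads), so without the repeat_kv calls the fp16 query, key, and value tensors passed to sdpa on this path would presumably have mismatched head counts. For reference, below is a minimal sketch of the repeat_kv helper as it typically appears in HF-transformers-style attention code; the (batch, num_kv_heads, seq_len, head_dim) layout is assumed here, not taken from this commit.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Sketch (assumed layout): expand (batch, num_kv_heads, seq_len, head_dim)
    # to (batch, num_kv_heads * n_rep, seq_len, head_dim) so K/V match the query heads.
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

With self.num_key_value_groups equal to num_attention_heads // num_key_value_heads, the expansion leaves all three tensors with the same number of heads before they are cast to float16 and handed to sdpa.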