fix qwen2 attention_mask slice (#12307)
parent 3df6195cb0
commit b9853f98b3
1 changed file with 3 additions and 0 deletions
@@ -560,6 +560,9 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, :kv_seq_len]
+
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
         import xe_addons
         xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
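Context for the fix (not part of the commit): once a KV cache is in play, the usable key/value length can differ from the last dimension of a pre-allocated 4-D additive mask, so the mask must be sliced to kv_seq_len before it is combined with the attention scores. Below is a minimal, hypothetical PyTorch sketch of the shape problem; the sizes and variable names are illustrative and not taken from the patched file.

    # Hypothetical sketch: why the mask's last dim must match kv_seq_len.
    import torch

    bsz, n_heads, q_len, kv_seq_len = 1, 2, 1, 8
    max_len = 16  # assume the mask was allocated for a longer sequence

    scores = torch.zeros(bsz, n_heads, q_len, kv_seq_len)
    attention_mask = torch.zeros(bsz, 1, q_len, max_len)

    # Without this slice, scores + attention_mask raises a RuntimeError,
    # since the trailing dims (8 vs 16) cannot broadcast.
    attention_mask = attention_mask[:, :, :, :kv_seq_len]

    out = scores + attention_mask  # broadcasts to (1, 2, 1, 8)
    print(out.shape)

Slicing unconditionally to kv_seq_len is cheap (it returns a view) and makes the mask shape correct whether or not the cache has grown past the current query length.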