fix qwen2 attention_mask slice (#12307)
This commit is contained in:
parent 3df6195cb0
commit b9853f98b3
1 changed file with 3 additions and 0 deletions
@@ -560,6 +560,9 @@ def qwen2_attention_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, :kv_seq_len]
+
     if should_use_fuse_rope(hidden_states, position_ids, self.training):
         import xe_addons
         xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
Loading…
Reference in a new issue