fix qwen2 attention_mask slice (#12307)

This commit is contained in:
Yishuo Wang 2024-10-31 17:00:05 +08:00 committed by GitHub
parent 3df6195cb0
commit b9853f98b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -560,6 +560,9 @@ def qwen2_attention_forward(
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
if attention_mask is not None:
attention_mask = attention_mask[:, :, :, :kv_seq_len]
if should_use_fuse_rope(hidden_states, position_ids, self.training):
import xe_addons
xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,