small improvement (#12359)

Yishuo Wang authored on 2024-11-07 15:57:41 +08:00, committed by GitHub
parent 71ea539351
commit ad68c56573

@@ -167,12 +167,8 @@ def qwen2_model_forward(
     from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
     from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
-    # ipex-llm changes start: don't generate `attention_mask` in specific cases
-    if seq_length == 1 or batch_size == 1 and use_sdp_causal(
-        seq_length, seq_length + past_key_values_length,
-        self.config.hidden_size // self.config.num_attention_heads,
-        inputs_embeds, self.training
-    ):
+    # ipex-llm changes start: don't generate `attention_mask` in decode phase
+    if seq_length == 1:
         attention_mask = None
     # ipex-llm changes end
     elif self._attn_implementation == "flash_attention_2":
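Below is a minimal standalone sketch, not part of the commit (tensor shapes and sizes are invented for illustration), of why the mask can be dropped when seq_length == 1: in the decode phase the single new query token sits at the last position, so its causal-mask row masks nothing, and scaled-dot-product attention gives the same result with attn_mask=None while skipping mask construction.

import torch
import torch.nn.functional as F

batch, heads, kv_len, head_dim = 1, 4, 16, 64
q = torch.randn(batch, heads, 1, head_dim)       # single new query token (decode step)
k = torch.randn(batch, heads, kv_len, head_dim)  # cached keys
v = torch.randn(batch, heads, kv_len, head_dim)  # cached values

# The causal-mask row for the last position masks nothing (all zeros),
# so adding it to the attention scores is a no-op.
causal_row = torch.zeros(batch, 1, 1, kv_len)

out_with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_row)
out_no_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(torch.allclose(out_with_mask, out_no_mask, atol=1e-6))  # True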