small improvement (#12359)
parent 71ea539351
commit ad68c56573

1 changed file with 2 additions and 6 deletions
@@ -167,12 +167,8 @@ def qwen2_model_forward(
     from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
     from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 
-    # ipex-llm changes start: don't generate `attention_mask` in specific cases
-    if seq_length == 1 or batch_size == 1 and use_sdp_causal(
-        seq_length, seq_length + past_key_values_length,
-        self.config.hidden_size // self.config.num_attention_heads,
-        inputs_embeds, self.training
-    ):
+    # ipex-llm changes start: don't generate `attention_mask` in decode phase
+    if seq_length == 1:
         attention_mask = None
     # ipex-llm changes end
     elif self._attn_implementation == "flash_attention_2":
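For context: the commit narrows the mask-skipping condition. Previously the 4D causal mask was also skipped when `batch_size == 1` and `use_sdp_causal(...)` held; after the change it is skipped only in the decode phase, i.e. when `seq_length == 1`. The standalone sketch below is not taken from the ipex-llm sources (the helper name and the mask-building details are made up for illustration); it only shows why a causal mask is redundant when a single query token attends to the KV cache, which is what makes `attention_mask = None` safe in that branch.

import torch

def build_causal_mask(seq_length: int, past_kv_length: int, dtype=torch.float32):
    # Additive causal mask of shape (seq_length, seq_length + past_kv_length):
    # 0 where attention is allowed, -inf where it is blocked.
    total = seq_length + past_kv_length
    q_pos = torch.arange(past_kv_length, total).unsqueeze(-1)  # (seq_length, 1)
    k_pos = torch.arange(total).unsqueeze(0)                   # (1, total)
    mask = torch.zeros(seq_length, total, dtype=dtype)
    return mask.masked_fill(k_pos > q_pos, float("-inf"))

if __name__ == "__main__":
    # Decode phase: one new token against 16 cached tokens. The mask is all
    # zeros, i.e. it has no effect, so passing None to the attention kernel
    # is equivalent and cheaper.
    decode_mask = build_causal_mask(seq_length=1, past_kv_length=16)
    assert torch.all(decode_mask == 0)

    # Prefill phase: several new tokens. Here the mask really blocks future
    # positions and cannot be skipped.
    prefill_mask = build_causal_mask(seq_length=4, past_kv_length=0)
    assert torch.isinf(prefill_mask).any()

For the prefill case the mask genuinely blocks future positions, which is presumably why the imports of `_prepare_4d_causal_attention_mask` and `_prepare_4d_causal_attention_mask_for_sdpa` stay in place for the `seq_length > 1` branches.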