From 374747b4922d1c6b8da1f62cc4c9bd8280507979 Mon Sep 17 00:00:00 2001 From: Yuwen Hu <54161268+Oscilloscope98@users.noreply.github.com> Date: Tue, 25 Mar 2025 16:12:03 +0800 Subject: [PATCH] Update bert optimization to fit higher transformers/torch version (#13006) --- python/llm/src/ipex_llm/transformers/convert.py | 4 ++++ python/llm/src/ipex_llm/transformers/models/bert.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py index d61d804d..4a025081 100644 --- a/python/llm/src/ipex_llm/transformers/convert.py +++ b/python/llm/src/ipex_llm/transformers/convert.py @@ -1275,6 +1275,8 @@ def _optimize_post(model): convert_forward(model, module.BertSelfAttention, self_attention_forward) + if hasattr(module, "BertSdpaSelfAttention"): + convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward) convert_forward(model, module.BertEncoder, encoder_forward) @@ -1863,6 +1865,8 @@ def _optimize_post(model): convert_forward(model, module.BertSelfAttention, self_attention_forward) + if hasattr(module, "BertSdpaSelfAttention"): + convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward) convert_forward(model, module.BertEncoder, encoder_forward) diff --git a/python/llm/src/ipex_llm/transformers/models/bert.py b/python/llm/src/ipex_llm/transformers/models/bert.py index 810d89b4..2f1c4f56 100644 --- a/python/llm/src/ipex_llm/transformers/models/bert.py +++ b/python/llm/src/ipex_llm/transformers/models/bert.py @@ -119,7 +119,7 @@ def encoder_forward( output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, ): - if not attention_mask.any(): + if attention_mask is not None and not attention_mask.any(): attention_mask = None return BertEncoder.forward( self=self,