diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index d61d804d..4a025081 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1275,6 +1275,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)
@@ -1863,6 +1865,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)
diff --git a/python/llm/src/ipex_llm/transformers/models/bert.py b/python/llm/src/ipex_llm/transformers/models/bert.py
index 810d89b4..2f1c4f56 100644
--- a/python/llm/src/ipex_llm/transformers/models/bert.py
+++ b/python/llm/src/ipex_llm/transformers/models/bert.py
@@ -119,7 +119,7 @@ def encoder_forward(
     output_hidden_states: Optional[bool] = False,
     return_dict: Optional[bool] = True,
 ):
-    if not attention_mask.any():
+    if attention_mask is not None and not attention_mask.any():
         attention_mask = None
     return BertEncoder.forward(
         self=self,
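
For context: the convert.py hunks patch the SDPA BERT attention class in addition to the eager one, and the bert.py hunk guards the all-zero-mask shortcut against a missing mask. Below is a minimal, self-contained sketch of both ideas, assuming nothing beyond PyTorch; the `convert_forward` body, `patch_bert_attention`, and `normalize_attention_mask` here are illustrative stand-ins, not the actual ipex-llm internals.

```python
# Minimal sketch of the two fixes in this diff, runnable on its own.
from typing import Optional

import torch


def convert_forward(model: torch.nn.Module, target_cls: type, new_forward) -> None:
    """Rebind `forward` on every submodule that is an instance of target_cls."""
    for submodule in model.modules():
        if isinstance(submodule, target_cls):
            submodule.forward = new_forward.__get__(submodule, target_cls)


def patch_bert_attention(model, module, self_attention_forward) -> None:
    """Patch both the eager and the SDPA BERT attention classes.

    Newer transformers releases instantiate BertSdpaSelfAttention (a
    subclass of BertSelfAttention) by default, so patching only the base
    class misses the module actually used at runtime; the hasattr() guard
    keeps this working on older releases that lack the SDPA class.
    """
    convert_forward(model, module.BertSelfAttention, self_attention_forward)
    if hasattr(module, "BertSdpaSelfAttention"):
        convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)


def normalize_attention_mask(
    attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Drop an all-zero (no-op) mask, tolerating a missing mask.

    The None check matters: calling .any() on None raises AttributeError,
    which is the crash the bert.py hunk guards against.
    """
    if attention_mask is not None and not attention_mask.any():
        attention_mask = None
    return attention_mask


if __name__ == "__main__":
    assert normalize_attention_mask(None) is None                     # no crash on a missing mask
    assert normalize_attention_mask(torch.zeros(1, 1, 1, 4)) is None  # all-zero mask is dropped
    assert normalize_attention_mask(torch.ones(1, 1, 1, 4)) is not None  # real mask passes through
```

The two guards are independent but serve the same goal: keeping the optimized BERT path working across transformers versions, whether the model routes attention through the SDPA subclass and whether the caller supplies an attention mask at all.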