Update bert optimization to fit higher transformers/torch version (#13006)

Yuwen Hu 2025-03-25 16:12:03 +08:00 committed by GitHub
parent 27d669210f
commit 374747b492
2 changed files with 5 additions and 1 deletion

@@ -1275,6 +1275,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)
@@ -1863,6 +1865,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)


@@ -119,7 +119,7 @@ def encoder_forward(
     output_hidden_states: Optional[bool] = False,
     return_dict: Optional[bool] = True,
 ):
-    if not attention_mask.any():
+    if attention_mask is not None and not attention_mask.any():
         attention_mask = None
     return BertEncoder.forward(
         self=self,
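The second change hardens the all-zero-mask shortcut. On newer transformers versions the encoder can receive attention_mask=None (for example on the SDPA path), so calling .any() unconditionally raises an AttributeError; note that plain truthiness (if attention_mask:) is not a valid guard either, since PyTorch rejects it for tensors with more than one element. A small standalone demonstration of the fixed check, assuming an additive-style mask where all zeros means nothing is masked:

import torch

def normalize_mask(attention_mask):
    # Fixed guard: skip .any() entirely when the mask is absent, and drop
    # an all-zero additive mask (it masks nothing) so downstream code can
    # take the mask-free fast path.
    if attention_mask is not None and not attention_mask.any():
        attention_mask = None
    return attention_mask

print(normalize_mask(None))                     # stays None, no crash
print(normalize_mask(torch.zeros(1, 1, 4, 4)))  # all-zero mask is dropped

mask = torch.zeros(1, 1, 4, 4)
mask[..., 2:] = float("-inf")                   # mask out the last two keys
print(normalize_mask(mask) is mask)             # True: a real mask is kept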