Update bert optimization to fit higher transformers/torch version (#13006)
parent 27d669210f
commit 374747b492
2 changed files with 5 additions and 1 deletion
@@ -1275,6 +1275,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)
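Why the new guard: recent transformers releases add an SDPA-based attention class, BertSdpaSelfAttention, and typically instantiate it by default (attn_implementation="sdpa"), so patching BertSelfAttention alone no longer reaches the modules a freshly loaded BertModel actually contains. Below is a minimal sketch of the guarded-patch pattern; convert_forward and optimized_forward here are illustrative stand-ins for the ipex_llm helpers named in the diff, under the assumption that the real helper matches module classes exactly:

# Sketch only: approximates, does not reproduce, the ipex_llm helpers.
import transformers.models.bert.modeling_bert as module


def optimized_forward(self, *args, **kwargs):
    """Stand-in for self_attention_forward (the optimized attention path)."""
    ...


def convert_forward(model, target_class, new_forward):
    # Rebind `forward` on every submodule whose class is exactly target_class.
    for sub in model.modules():
        if sub.__class__ is target_class:
            sub.forward = new_forward.__get__(sub, target_class)


def patch_bert(model):
    convert_forward(model, module.BertSelfAttention, optimized_forward)
    # BertSdpaSelfAttention exists only on newer transformers releases and
    # is the class a default-loaded model actually uses there; the hasattr
    # guard keeps this code working on older releases where it is absent.
    if hasattr(module, "BertSdpaSelfAttention"):
        convert_forward(model, module.BertSdpaSelfAttention,
                        optimized_forward)

If matching is by exact class, the subclass needs its own call even though it inherits from BertSelfAttention, which is why the diff patches both.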
@@ -1863,6 +1865,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)
@@ -119,7 +119,7 @@ def encoder_forward(
     output_hidden_states: Optional[bool] = False,
     return_dict: Optional[bool] = True,
 ):
-    if not attention_mask.any():
+    if attention_mask is not None and not attention_mask.any():
         attention_mask = None
     return BertEncoder.forward(
         self=self,
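The second file's change hardens the fast path that drops an all-zero (no-op) attention mask: newer transformers/torch combinations can hand the encoder attention_mask=None, on which the old bare .any() call would crash. A small self-contained sketch of the guard (normalize_mask is a hypothetical name, not part of the diff); note that a plain `if attention_mask:` would not work either, because the truth value of a multi-element tensor is ambiguous:

# Hypothetical standalone reproduction of the guarded mask shortcut above.
import torch


def normalize_mask(attention_mask):
    # `attention_mask.any()` alone raises AttributeError when the mask is
    # None; `if attention_mask:` alone raises RuntimeError for tensors with
    # more than one element ("Boolean value of Tensor ... is ambiguous").
    if attention_mask is not None and not attention_mask.any():
        attention_mask = None  # all-zero mask is a no-op: drop it entirely
    return attention_mask


assert normalize_mask(None) is None                 # tolerated, no crash
assert normalize_mask(torch.zeros(2, 8)) is None    # no-op mask dropped
m = torch.ones(2, 8)
assert normalize_mask(m) is m                       # meaningful mask kept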