Update bert optimization to fit higher transformers/torch version (#13006)
This commit is contained in:
parent
27d669210f
commit
374747b492
2 changed files with 5 additions and 1 deletions
|
|
@ -1275,6 +1275,8 @@ def _optimize_post(model):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.BertSelfAttention,
|
module.BertSelfAttention,
|
||||||
self_attention_forward)
|
self_attention_forward)
|
||||||
|
if hasattr(module, "BertSdpaSelfAttention"):
|
||||||
|
convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.BertEncoder,
|
module.BertEncoder,
|
||||||
encoder_forward)
|
encoder_forward)
|
||||||
|
|
@ -1863,6 +1865,8 @@ def _optimize_post(model):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.BertSelfAttention,
|
module.BertSelfAttention,
|
||||||
self_attention_forward)
|
self_attention_forward)
|
||||||
|
if hasattr(module, "BertSdpaSelfAttention"):
|
||||||
|
convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.BertEncoder,
|
module.BertEncoder,
|
||||||
encoder_forward)
|
encoder_forward)
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,7 @@ def encoder_forward(
|
||||||
output_hidden_states: Optional[bool] = False,
|
output_hidden_states: Optional[bool] = False,
|
||||||
return_dict: Optional[bool] = True,
|
return_dict: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
if not attention_mask.any():
|
if attention_mask and not attention_mask.any():
|
||||||
attention_mask = None
|
attention_mask = None
|
||||||
return BertEncoder.forward(
|
return BertEncoder.forward(
|
||||||
self=self,
|
self=self,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue