Update bert optimization to fit higher transformers/torch version (#13006)
parent 27d669210f
commit 374747b492

2 changed files with 5 additions and 1 deletion
@@ -1275,6 +1275,8 @@ def _optimize_post(model):
                 convert_forward(model,
                                 module.BertSelfAttention,
                                 self_attention_forward)
+                if hasattr(module, "BertSdpaSelfAttention"):
+                    convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
                 convert_forward(model,
                                 module.BertEncoder,
                                 encoder_forward)
@@ -1863,6 +1865,8 @@ def _optimize_post(model):
         convert_forward(model,
                         module.BertSelfAttention,
                         self_attention_forward)
+        if hasattr(module, "BertSdpaSelfAttention"):
+            convert_forward(model, module.BertSdpaSelfAttention, self_attention_forward)
         convert_forward(model,
                         module.BertEncoder,
                         encoder_forward)
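Note on the two hunks above: newer transformers releases add an SDPA-based BertSdpaSelfAttention class for BERT (used when torch's scaled_dot_product_attention backend is selected), so the optimized self_attention_forward must also be registered for that class; the hasattr guard keeps older transformers versions, which only define BertSelfAttention, working unchanged. The snippet below is only a minimal sketch of what a convert_forward-style helper typically does (rebinding the forward method of every matching submodule); it is not the actual ipex-llm implementation.

import types

import torch.nn as nn


def convert_forward(model: nn.Module, target_cls, new_forward):
    # Illustrative sketch only: walk the module tree and rebind `forward`
    # on every submodule that is an instance of target_cls.
    for _, sub_module in model.named_modules():
        if isinstance(sub_module, target_cls):
            sub_module.forward = types.MethodType(new_forward, sub_module)

With a helper like that, the added lines simply repeat the existing registration pattern for module.BertSdpaSelfAttention whenever the installed transformers version defines it.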
@@ -119,7 +119,7 @@ def encoder_forward(
     output_hidden_states: Optional[bool] = False,
     return_dict: Optional[bool] = True,
 ):
-    if not attention_mask.any():
+    if attention_mask is not None and not attention_mask.any():
         attention_mask = None
     return BertEncoder.forward(
         self=self,
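Note on the hunk above: newer transformers versions can call the patched encoder with attention_mask=None (for instance when no padding mask is needed), and the old unconditional attention_mask.any() call then raises AttributeError on None; the added guard only performs the drop-the-mask short-circuit when a mask tensor is actually present. A small self-contained illustration of the guarded check follows; the helper name normalize_mask is made up for the example, since the real code performs the check inline before delegating to BertEncoder.forward.

import torch


def normalize_mask(attention_mask):
    # Mirrors the guarded check from the diff: if a mask tensor is present
    # but has no nonzero entries, drop it so the encoder runs without one;
    # a missing mask (None) is passed through instead of crashing.
    if attention_mask is not None and not attention_mask.any():
        return None
    return attention_mask


print(normalize_mask(None))                          # None (the case that used to raise AttributeError)
print(normalize_mask(torch.zeros(1, 1, 1, 8)))       # None (mask with no nonzero entries is dropped)
print(normalize_mask(torch.ones(1, 1, 1, 8)).shape)  # torch.Size([1, 1, 1, 8]) (a real mask is kept)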