use new fp32 softmax kernel (#11776)
parent 23d3acdc77
commit aa861df066

2 changed files with 6 additions and 5 deletions
@@ -42,8 +42,9 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)
+
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)

@@ -184,9 +184,9 @@ def attention_forward(
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
-        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                                   dtype=torch.float32).to(value_states.dtype)
+        import xe_addons
+        xe_addons.attn_softmax_inplaced(attn_weights)
+
         attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                    training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
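For context, both hunks swap torch.nn.functional.softmax for the in-place fp32 softmax kernel exposed by xe_addons. Below is a minimal sketch of how the new call could be guarded with a fallback to the PyTorch path it replaces; the helper name _attn_softmax and the ImportError guard are assumptions for illustration, not part of this commit. Only xe_addons.attn_softmax_inplaced and the torch softmax/upcast handling come from the diff.

import torch

try:
    import xe_addons  # XPU extension providing the fused in-place fp32 softmax
except ImportError:   # assumption: fall back when the extension is unavailable
    xe_addons = None


def _attn_softmax(attn_weights: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not in the commit): softmax over the last dimension.
    if xe_addons is not None:
        # Fused kernel computes the softmax in fp32 and writes the result in place.
        xe_addons.attn_softmax_inplaced(attn_weights)
        return attn_weights
    # Fallback mirroring the removed PyTorch path: upcast to fp32, then cast back.
    return torch.nn.functional.softmax(
        attn_weights, dim=-1, dtype=torch.float32
    ).to(attn_weights.dtype)

In the diff, the result is consumed exactly as before: dropout is applied to attn_weights and the output is computed with torch.matmul(attn_weights, value_states), so the kernel is a drop-in replacement for the softmax step only.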