use new fp32 softmax kernel (#11776)

Yishuo Wang 2024-08-13 14:48:11 +08:00 committed by GitHub
parent 23d3acdc77
commit aa861df066
2 changed files with 6 additions and 5 deletions


@@ -42,8 +42,9 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
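
For context, a minimal sketch (not part of this commit) of how the old and new softmax paths relate. It assumes xe_addons.attn_softmax_inplaced computes softmax over the last dimension with fp32 accumulation and writes the result back into its argument in place, matching the call pattern above; the attn_softmax helper and its ImportError fallback are illustrative only.

import torch


def attn_softmax(attn_weights: torch.Tensor, value_dtype: torch.dtype) -> torch.Tensor:
    try:
        # ipex-llm native kernel used by this commit
        # (assumed: in-place softmax over the last dim with fp32 accumulation)
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
        return attn_weights
    except ImportError:
        # reference path removed by this commit: upcast to fp32, softmax,
        # then cast back to the value dtype
        return torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(value_dtype)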


@@ -184,9 +184,9 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
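
Note on the change: in the second hunk, the removed code explicitly upcast the attention weights to fp32 (dtype=torch.float32) and then cast the result back to value_states.dtype, which materializes an extra fp32 copy of the weights. Replacing both paths with a single in-place kernel call presumably fuses the upcast, softmax, and downcast and avoids that temporary allocation, which appears to be the motivation for the "new fp32 softmax kernel" named in the commit title.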