optimize siglip attention on arc (#12569)

commit a4eb561f36 (parent 1a2ab12876)
Author: Yishuo Wang (committed by GitHub)
Date:   2024-12-18 14:19:43 +08:00

@@ -53,28 +53,39 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    query_states, key_states, value_states = padding_qkv_hd(
-        query_states, key_states, value_states,
-        72, 80
-    )
-
-    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+    from ipex_llm.transformers.utils import get_xpu_device_type
+    if (
+        self.head_dim == 72
+        and get_xpu_device_type(query_states) in ["arc", "flex"] and
+        query_states.dtype in [torch.float, torch.half]
+    ):
         import xe_addons
         attn_weights = None
-        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
-                                               value_states.contiguous(), attention_mask)
+        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
+                                                      value_states, attention_mask)
     else:
-        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+        query_states, key_states, value_states = padding_qkv_hd(
+            query_states, key_states, value_states,
+            72, 80
+        )
 
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+            import xe_addons
+            attn_weights = None
+            attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                                   value_states.contiguous(), attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
 
-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = torch.nn.functional.dropout(attn_weights,
-                                                   p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-    attn_output = attn_output[:, :, :, :self.head_dim]
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
+
+            attn_weights = attention_softmax(attn_weights)
+            attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                       p=self.dropout, training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
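
For context on what the new fast path replaces: xe_addons.siglip_sdp_non_causal handles head_dim == 72 directly on Arc/Flex, so query/key/value no longer need to be padded to head size 80 (and sliced back afterwards) before calling the generic SDP kernel; every other device/dtype combination still takes the previous padded route. The snippet below is a minimal reference sketch of the non-causal scaled-dot-product attention that both paths compute, mirroring the eager else branch of the diff. The name reference_sdp_non_causal, the default scale, and the tensor shapes are illustrative assumptions, not part of the ipex-llm API.

import torch

def reference_sdp_non_causal(query, key, value, attention_mask=None, scale=None):
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    # Mirrors the eager branch of the diff: scaled QK^T, optional additive mask,
    # fp32 softmax, then a weighted sum over values. No causal mask is applied.
    # (Dropout from the eager branch is omitted; it is a no-op at inference.)
    scale = scale if scale is not None else query.size(-1) ** -0.5  # assumed default for self.scale
    attn_weights = torch.matmul(query * scale, key.transpose(2, 3))
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)

# Illustrative shapes only: head_dim == 72 is exactly the size the new Arc/Flex
# fast path checks for before dispatching to xe_addons.siglip_sdp_non_causal.
q = torch.randn(1, 16, 1024, 72)
k = torch.randn(1, 16, 1024, 72)
v = torch.randn(1, 16, 1024, 72)
out = reference_sdp_non_causal(q, k, v)
print(out.shape)  # torch.Size([1, 16, 1024, 72])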