optimize siglip attention on arc (#12569)
This commit is contained in:
parent
1a2ab12876
commit
a4eb561f36
1 changed file with 27 additions and 16 deletions
@@ -53,28 +53,39 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
-    query_states, key_states, value_states = padding_qkv_hd(
-        query_states, key_states, value_states,
-        72, 80
-    )
-
-    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+    from ipex_llm.transformers.utils import get_xpu_device_type
+    if (
+        self.head_dim == 72
+        and get_xpu_device_type(query_states) in ["arc", "flex"] and
+        query_states.dtype in [torch.float, torch.half]
+    ):
         import xe_addons
         attn_weights = None
-        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
-                                               value_states.contiguous(), attention_mask)
+        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
+                                                      value_states, attention_mask)
     else:
-        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        query_states, key_states, value_states = padding_qkv_hd(
+            query_states, key_states, value_states,
+            72, 80
+        )
+
+        if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+            import xe_addons
+            attn_weights = None
+            attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                                   value_states.contiguous(), attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask
 
-        attn_weights = attention_softmax(attn_weights)
+            attn_weights = attention_softmax(attn_weights)
 
-        attn_weights = torch.nn.functional.dropout(attn_weights,
-                                                   p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+            attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                       p=self.dropout, training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)
 
-    attn_output = attn_output[:, :, :, :self.head_dim]
+        attn_output = attn_output[:, :, :, :self.head_dim]
 
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
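
Why the fallback branch can pad the head dimension from 72 to 80 with padding_qkv_hd and still recover the exact result with the final attn_output[:, :, :, :self.head_dim] slice: zero-padding the last dimension of the query and key states leaves q·kᵀ unchanged, and the zero-padded columns of the value states only produce trailing zero channels that the slice discards. The sketch below is not part of the commit; it is a minimal, self-contained illustration in plain PyTorch (no ipex_llm, xe_addons, or XPU device required), with illustrative shapes and a simple softmax-attention helper standing in for the library code.

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim, padded_dim = 1, 2, 16, 72, 80
scale = head_dim ** -0.5  # mirrors self.scale, computed from the original head_dim

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

def attention(q, k, v):
    # plain softmax attention, matching the non-SDP branch of the diff
    weights = torch.matmul(q * scale, k.transpose(2, 3))
    weights = torch.softmax(weights, dim=-1)
    return torch.matmul(weights, v)

# reference result at the native head_dim of 72
ref = attention(q, k, v)

# padded path: zero-pad the last dim to 80, attend, then slice back to 72
pad = (0, padded_dim - head_dim)  # F.pad pads the last dimension first
q_p, k_p, v_p = (F.pad(t, pad) for t in (q, k, v))
out = attention(q_p, k_p, v_p)[:, :, :, :head_dim]

print(torch.allclose(ref, out, atol=1e-6))  # expected: True

On Arc and Flex GPUs with head_dim == 72 and fp32/fp16 inputs, the new branch added by this commit skips the padding entirely and passes the native 72-wide heads to xe_addons.siglip_sdp_non_causal.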