optimize siglip attention on arc (#12569)
This commit is contained in:
parent 1a2ab12876
commit a4eb561f36
1 changed file with 27 additions and 16 deletions
@@ -53,28 +53,39 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)

-    query_states, key_states, value_states = padding_qkv_hd(
-        query_states, key_states, value_states,
-        72, 80
-    )
-
-    if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+    from ipex_llm.transformers.utils import get_xpu_device_type
+    if (
+        self.head_dim == 72
+        and get_xpu_device_type(query_states) in ["arc", "flex"] and
+        query_states.dtype in [torch.float, torch.half]
+    ):
         import xe_addons
         attn_weights = None
-        attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
-                                               value_states.contiguous(), attention_mask)
+        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
+                                                      value_states, attention_mask)
     else:
-        attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
+        query_states, key_states, value_states = padding_qkv_hd(
+            query_states, key_states, value_states,
+            72, 80
+        )

-        if attention_mask is not None:
-            attn_weights = attn_weights + attention_mask
+        if use_sdp_non_causal(query_states.size(-1), query_states.device, query_states.dtype):
+            import xe_addons
+            attn_weights = None
+            attn_output = xe_addons.sdp_non_causal(query_states, key_states.contiguous(),
+                                                   value_states.contiguous(), attention_mask)
+        else:
+            attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))

-        attn_weights = attention_softmax(attn_weights)
-        attn_weights = torch.nn.functional.dropout(attn_weights,
-                                                   p=self.dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
+            if attention_mask is not None:
+                attn_weights = attn_weights + attention_mask

-    attn_output = attn_output[:, :, :, :self.head_dim]
+            attn_weights = attention_softmax(attn_weights)
+            attn_weights = torch.nn.functional.dropout(attn_weights,
+                                                       p=self.dropout, training=self.training)
+            attn_output = torch.matmul(attn_weights, value_states)

+        attn_output = attn_output[:, :, :, :self.head_dim]
+
     attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
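For context, a minimal PyTorch-only sketch of the two paths this change distinguishes. The fused XPU kernels (xe_addons.siglip_sdp_non_causal, xe_addons.sdp_non_causal) and the padding_qkv_hd / get_xpu_device_type helpers are IPEX-LLM internals, so they are stood in for here by torch.nn.functional.scaled_dot_product_attention and plain zero-padding; pad_head_dim, siglip_attention_sketch, and the on_arc_or_flex flag are illustrative names, not part of the repository. Assumes PyTorch 2.1+ for the scale= argument.

import torch
import torch.nn.functional as F

def pad_head_dim(q, k, v, new_hd=80):
    # Rough stand-in for padding_qkv_hd: zero-pad the head (last) dimension
    # so kernels that only support certain head sizes (e.g. 80) can be used
    # with SigLIP's head_dim of 72.
    pad = new_hd - q.size(-1)
    return F.pad(q, (0, pad)), F.pad(k, (0, pad)), F.pad(v, (0, pad))

def siglip_attention_sketch(q, k, v, attn_mask=None, on_arc_or_flex=False):
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    head_dim = q.size(-1)
    scale = head_dim ** -0.5
    if on_arc_or_flex and head_dim == 72 and q.dtype in (torch.float, torch.half):
        # Path added by this commit: a fused kernel takes the 72-wide heads
        # directly, so no padding or output slicing is needed.
        # Emulated here with PyTorch SDPA.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
    # Previous/fallback path: pad heads from 72 to 80, run attention, then
    # slice the output back to the original head size.
    q, k, v = pad_head_dim(q, k, v, new_hd=80)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
    return out[..., :head_dim]

# SigLIP-like shapes: 16 heads of size 72 over 729 patch tokens.
q, k, v = (torch.randn(1, 16, 729, 72) for _ in range(3))
print(siglip_attention_sketch(q, k, v, on_arc_or_flex=True).shape)  # torch.Size([1, 16, 729, 72])

Zero-padding q and k leaves the attention scores unchanged and the padded value columns are sliced away afterwards, so both paths produce the same result; the new Arc/Flex branch simply skips the padding and slicing overhead for the SigLIP head size of 72.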