optimize siglip attention on arc (#12569)

Yishuo Wang 2024-12-18 14:19:43 +08:00 committed by GitHub
parent 1a2ab12876
commit a4eb561f36

@@ -53,6 +53,17 @@ def siglip_attention_forward(
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.chunk(3, dim=1)
    from ipex_llm.transformers.utils import get_xpu_device_type
    if (
        self.head_dim == 72
        and get_xpu_device_type(query_states) in ["arc", "flex"] and
        query_states.dtype in [torch.float, torch.half]
    ):
        import xe_addons
        attn_weights = None
        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
                                                      value_states, attention_softmax)
    else:
        query_states, key_states, value_states = padding_qkv_hd(
            query_states, key_states, value_states,
            72, 80
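
The change adds a fast path for SigLIP attention on Intel Arc and Flex GPUs: when head_dim is 72 and the dtype is fp32 or fp16, it calls the fused xe_addons.siglip_sdp_non_causal kernel; otherwise it keeps the existing fallback that pads the head dimension from 72 to 80 before running scaled-dot-product attention. Below is a minimal, hypothetical PyTorch sketch of why such head-dim padding is numerically safe; sdpa_with_padded_head_dim and the tensor shapes are invented for illustration and are not ipex-llm APIs, and the scale argument assumes PyTorch 2.1 or newer.

# Minimal sketch (assumed names, not ipex-llm code): SDPA with the head
# dimension zero-padded from 72 to 80 and the result sliced back.
import math
import torch
import torch.nn.functional as F

def sdpa_with_padded_head_dim(q, k, v, attn_mask=None, orig_hd=72, padded_hd=80):
    # q/k/v: [batch, num_heads, seq_len, orig_hd]
    pad = padded_hd - orig_hd
    q, k, v = (F.pad(t, (0, pad)) for t in (q, k, v))  # zero-pad the last dim
    # Zero columns contribute nothing to the q.k scores and only produce zero
    # output channels, so the result is unchanged; the softmax scale must
    # still be based on the original head size, not the padded one.
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, scale=1.0 / math.sqrt(orig_hd)
    )
    return out[..., :orig_hd]

# Illustrative SigLIP-like shapes: 16 heads, head_dim 72.
q = torch.randn(1, 16, 64, 72)
k = torch.randn(1, 16, 64, 72)
v = torch.randn(1, 16, 64, 72)
print(sdpa_with_padded_head_dim(q, k, v).shape)  # torch.Size([1, 16, 64, 72])

In the real code path, padding_qkv_hd plays this padding role, while the fused kernel on Arc/Flex skips the padding and the slice-back entirely, which is presumably where the savings come from.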