optimize siglip attention on arc (#12569)
parent 1a2ab12876
commit a4eb561f36
1 changed file with 27 additions and 16 deletions
@@ -53,6 +53,17 @@ def siglip_attention_forward(
    qkv = qkv.transpose(1, 2)
    query_states, key_states, value_states = qkv.chunk(3, dim=1)

    from ipex_llm.transformers.utils import get_xpu_device_type
    if (
        self.head_dim == 72
        and get_xpu_device_type(query_states) in ["arc", "flex"]
        and query_states.dtype in [torch.float, torch.half]
    ):
        import xe_addons
        attn_weights = None
        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
                                                      value_states, attention_softmax)
    else:
        query_states, key_states, value_states = padding_qkv_hd(
            query_states, key_states, value_states,
            72, 80
        )
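The new branch dispatches on head dimension, device type, and dtype: when head_dim is 72 on an Arc or Flex GPU with fp32/fp16 tensors, attention goes through the fused xe_addons.siglip_sdp_non_causal kernel; otherwise the heads are padded from 72 to 80 via padding_qkv_hd before the generic path. A minimal sketch of the fallback idea, assuming a [batch, heads, seq, head_dim] layout and PyTorch >= 2.1 for the explicit scale argument; pad_heads and sdpa_padded below are illustrative stand-ins, not the ipex-llm helpers:

import torch
import torch.nn.functional as F

def pad_heads(t: torch.Tensor, old_hd: int = 72, new_hd: int = 80) -> torch.Tensor:
    # Zero-pad the last dimension (head_dim) from 72 up to 80.
    return F.pad(t, (0, new_hd - old_hd))

def sdpa_padded(query_states, key_states, value_states, attention_mask=None):
    # Illustrative stand-in for the padding_qkv_hd fallback, not ipex-llm's code.
    q, k, v = pad_heads(query_states), pad_heads(key_states), pad_heads(value_states)
    # Zero-padding q and k leaves q @ k^T unchanged, so attention scores are identical;
    # keep the original 1/sqrt(72) scaling instead of SDPA's default 1/sqrt(80).
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask,
                                         scale=query_states.shape[-1] ** -0.5)
    # Drop the padded value columns to recover the original head_dim of 72.
    return out[..., :query_states.shape[-1]]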