optimize siglip attention on arc (#12569)
parent 1a2ab12876
commit a4eb561f36
1 changed file with 27 additions and 16 deletions
@@ -53,6 +53,17 @@ def siglip_attention_forward(
     qkv = qkv.transpose(1, 2)
     query_states, key_states, value_states = qkv.chunk(3, dim=1)
 
+    from ipex_llm.transformers.utils import get_xpu_device_type
+    if (
+        self.head_dim == 72
+        and get_xpu_device_type(query_states) in ["arc", "flex"] and
+        query_states.dtype in [torch.float, torch.half]
+    ):
+        import xe_addons
+        attn_weights = None
+        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
+                                                      value_states, attention_softmax)
+    else:
     query_states, key_states, value_states = padding_qkv_hd(
         query_states, key_states, value_states,
         72, 80
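In short, on Arc and Flex GPUs the 72-wide SigLIP heads are handed straight to a fused non-causal SDP kernel, while every other device keeps the existing path that pads the head dimension from 72 to 80 before attention. The sketch below illustrates the two branches with plain PyTorch; the function name, the attention_mask argument, and the use of F.scaled_dot_product_attention (PyTorch >= 2.1 for the scale keyword) are illustrative stand-ins, not the actual xe_addons.siglip_sdp_non_causal kernel or padding_qkv_hd helper.

import torch
import torch.nn.functional as F

def siglip_sdp_sketch(query_states, key_states, value_states, attention_mask=None,
                      use_fused_path=False):
    """Rough, illustrative equivalent of the two branches in the diff above.

    Tensors are assumed to be (batch, num_heads, seq_len, head_dim) with
    head_dim == 72, as produced by the qkv chunk shown in the hunk.
    """
    head_dim = query_states.size(-1)
    scale = head_dim ** -0.5  # keep the 1/sqrt(72) softmax scale regardless of padding

    if use_fused_path:
        # Arc/Flex branch: the commit calls the fused xe_addons kernel here;
        # mathematically it is plain non-causal scaled-dot-product attention.
        return F.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attention_mask, is_causal=False, scale=scale)

    # Generic branch: pad head_dim from 72 to 80 (roughly what
    # padding_qkv_hd(..., 72, 80) does), run attention on the padded heads,
    # then drop the padded columns from the output.
    pad = 80 - head_dim
    q = F.pad(query_states, (0, pad))
    k = F.pad(key_states, (0, pad))
    v = F.pad(value_states, (0, pad))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask,
                                         is_causal=False, scale=scale)
    return out[..., :head_dim]

Zero-padding q, k, and v leaves the attention scores unchanged, so as long as the softmax scale stays at 1/sqrt(72), only the final slice differs from the unpadded math; the fused kernel simply avoids that pad-and-slice round trip for the odd head size of 72.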