optimize siglip attention again (#12578)
This commit is contained in:
parent e0921f80c1
commit 4540424271
3 changed files with 41 additions and 5 deletions
@@ -237,27 +237,50 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
 
         # compute
+        # import xe_addons
+        # if is_causal:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        # elif seq_length != kv_length and seq_length <= 32:
+        #     # todo: add scale support
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+        #     else:
+        #         attn_output = xe_addons.sdp(query, key, value, mask)
+        # else:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask)
         elif seq_length != kv_length and seq_length <= 32:
-            # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
 
         return attn_output
     else:
         mask = mask[..., :seq_length, :kv_length] if mask is not None else None
+
+        from ipex_llm.transformers.models.utils import repeat_kv
+        if n_heads != n_kv_heads:
+            key = repeat_kv(key, n_heads // n_kv_heads)
+            value = repeat_kv(value, n_heads // n_kv_heads)
+
         return torch.nn.functional.scaled_dot_product_attention(
             query, key, value, mask, is_causal=is_causal, scale=scale
         )
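For reference, a minimal runnable sketch of the fallback branch this hunk adds: when the query carries more heads than the key/value tensors (grouped-query attention), the KV heads are expanded before handing off to torch's built-in SDPA. Here torch.repeat_interleave stands in for the new repeat_kv helper, and the shapes are illustrative, not taken from the commit.

import torch

bsz, n_heads, n_kv_heads, seq_length, head_dim = 1, 8, 2, 16, 64
query = torch.randn(bsz, n_heads, seq_length, head_dim)
key = torch.randn(bsz, n_kv_heads, seq_length, head_dim)
value = torch.randn(bsz, n_kv_heads, seq_length, head_dim)

# expand grouped KV heads to match the query head count (stand-in for repeat_kv)
if n_heads != n_kv_heads:
    key = key.repeat_interleave(n_heads // n_kv_heads, dim=1)
    value = value.repeat_interleave(n_heads // n_kv_heads, dim=1)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, None, is_causal=True, scale=head_dim ** -0.5
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])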
@@ -59,6 +59,10 @@ def siglip_attention_forward(
         and get_xpu_device_type(query_states) in ["arc", "flex"] and
         query_states.dtype in [torch.float, torch.half]
     ):
+        n_heads, kv_length = query_states.size(1), key_states.size(2)
+        from ipex_llm.transformers.models.common import prepare_mask
+        attention_mask = prepare_mask(attention_mask, bsz, n_heads, q_len, kv_length,
+                                      False, query_states.dtype, query_states.device)
         import xe_addons
         attn_weights = None
         attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
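A rough reference for the tensors this path now hands to the fused kernel, assuming xe_addons.siglip_sdp_non_causal computes standard non-causal scaled-dot-product attention with an additive mask; the all-zero mask below stands in for the real prepare_mask output, and the SigLIP-like shapes are illustrative only.

import torch

bsz, n_heads, q_len, head_dim = 1, 16, 729, 72
query_states = torch.randn(bsz, n_heads, q_len, head_dim)
key_states = torch.randn(bsz, n_heads, q_len, head_dim)
value_states = torch.randn(bsz, n_heads, q_len, head_dim)

kv_length = key_states.size(2)
# stand-in for prepare_mask(...): an additive mask already broadcast to
# (bsz, n_heads, q_len, kv_length); zeros mean "attend everywhere"
attention_mask = torch.zeros(bsz, n_heads, q_len, kv_length, dtype=query_states.dtype)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attention_mask, is_causal=False
)
print(attn_output.shape)  # torch.Size([1, 16, 729, 72])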
@@ -388,6 +388,15 @@ def fp16_fusion_check(proj, x, training):
     return True
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def update_past_key_value(past_key_value, key_states, value_states,
                           kv_seq_len, use_quantize_kv, device):
     bsz, num_heads, _, head_dim = key_states.shape
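A quick sanity sketch (not part of the commit) of the new repeat_kv helper: the expand-then-reshape repeats each KV head n_rep times consecutively, which matches torch.repeat_interleave on the head dimension; the expand is only a view, so no copy is made until the final reshape.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, n_kv_heads, seq, head_dim) -> (batch, n_kv_heads * n_rep, seq, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
                                                           n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 10, 64)
assert torch.equal(repeat_kv(kv, 3), kv.repeat_interleave(3, dim=1))
print(repeat_kv(kv, 3).shape)  # torch.Size([2, 12, 10, 64])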