optimize siglip attention again (#12578)
parent e0921f80c1
commit 4540424271
3 changed files with 41 additions and 5 deletions
@@ -237,27 +237,50 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
        mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

        # compute
        # import xe_addons
        # if is_causal:
        #     if key.dtype == torch.uint8:
        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
        #     else:
        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
        # elif seq_length != kv_length and seq_length <= 32:
        #     # todo: add scale support
        #     if key.dtype == torch.uint8:
        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
        #     else:
        #         attn_output = xe_addons.sdp(query, key, value, mask)
        # else:
        #     if key.dtype == torch.uint8:
        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
        #     else:
        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

        import xe_addons
        if is_causal:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
                attn_output = xe_addons.sdp_causal(query, key, value, mask)
        elif seq_length != kv_length and seq_length <= 32:
            # todo: add scale support
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp(query, key, value, mask)
        else:
            if key.dtype == torch.uint8:
                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
            else:
                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)

        return attn_output
    else:
        mask = mask[..., :seq_length, :kv_length] if mask is not None else None

        from ipex_llm.transformers.models.utils import repeat_kv
        if n_heads != n_kv_heads:
            key = repeat_kv(key, n_heads // n_kv_heads)
            value = repeat_kv(value, n_heads // n_kv_heads)

        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, mask, is_causal=is_causal, scale=scale
        )

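Judging from the torch.nn.functional.scaled_dot_product_attention fallback in the same hunk, each xe_addons.sdp_* kernel above is a fused XPU implementation of standard scaled dot-product attention, dispatched on causality, query length, and whether the KV cache is stored as uint8 (FP8). As a rough reference only, not code from this patch, the unfused computation those kernels replace looks like this:

import math
import torch

def sdp_reference(query, key, value, mask=None, scale=None):
    # query/key/value: (batch, n_heads, seq_len, head_dim); mask is additive.
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
    if mask is not None:
        attn_weights = attn_weights + mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)

The else branch reaches the same result through torch's built-in SDPA after repeat_kv has expanded grouped key/value heads to match the query head count.
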
@@ -59,6 +59,10 @@ def siglip_attention_forward(
        and get_xpu_device_type(query_states) in ["arc", "flex"] and
        query_states.dtype in [torch.float, torch.half]
    ):
        n_heads, kv_length = query_states.size(1), key_states.size(2)
        from ipex_llm.transformers.models.common import prepare_mask
        attention_mask = prepare_mask(attention_mask, bsz, n_heads, q_len, kv_length,
                                      False, query_states.dtype, query_states.device)
        import xe_addons
        attn_weights = None
        attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,

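prepare_mask itself is defined in ipex_llm.transformers.models.common and is not shown in this diff. Purely as an illustration of the call signature, a hypothetical sketch rather than the library's implementation, a helper with these arguments would typically broadcast or synthesize an additive mask of shape (bsz, n_heads, seq_length, kv_length) for the fused kernel:

import torch

def prepare_mask_sketch(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
    # Hypothetical stand-in for ipex_llm's prepare_mask, for illustration only.
    if mask is None:
        mask = torch.zeros(seq_length, kv_length, dtype=dtype, device=device)
        if is_causal:
            # Additive causal mask: large negative values above the shifted diagonal.
            blocked = torch.ones(seq_length, kv_length, dtype=torch.bool, device=device)
            blocked = torch.triu(blocked, diagonal=kv_length - seq_length + 1)
            mask = mask.masked_fill(blocked, torch.finfo(dtype).min)
    else:
        mask = mask[..., :seq_length, :kv_length].to(dtype=dtype, device=device)
    # Expand to the (bsz, n_heads, seq_length, kv_length) layout the fused kernel consumes.
    return mask.expand(bsz, n_heads, seq_length, kv_length).contiguous()

In the SigLIP path above the mask is non-causal (the False argument), so the work reduces to slicing, casting, and broadcasting the incoming attention_mask.
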
@@ -388,6 +388,15 @@ def fp16_fusion_check(proj, x, training):
    return True


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
                                                           n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def update_past_key_value(past_key_value, key_states, value_states,
                          kv_seq_len, use_quantize_kv, device):
    bsz, num_heads, _, head_dim = key_states.shape

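The new repeat_kv helper follows the same pattern as the repeat_kv utility found in Hugging Face transformers models: it tiles each key/value head n_rep times along the head dimension (equivalent to torch.repeat_interleave on dim=1) so that grouped-query KV tensors match the query head count before the torch SDPA fallback in the first hunk. A small shape check, illustrative only:

import torch
from ipex_llm.transformers.models.utils import repeat_kv

# 8 KV heads shared across 32 query heads -> repeat each KV head 4 times.
key = torch.randn(1, 8, 128, 64)   # (batch, n_kv_heads, seq_len, head_dim)
print(repeat_kv(key, 4).shape)     # torch.Size([1, 32, 128, 64])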