optimize siglip attention again (#12578)

Yishuo Wang 2024-12-19 13:40:48 +08:00 committed by GitHub
parent e0921f80c1
commit 4540424271
3 changed files with 41 additions and 5 deletions

@@ -237,27 +237,50 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value:
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
         # compute
+        # import xe_addons
+        # if is_causal:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+        # elif seq_length != kv_length and seq_length <= 32:
+        #     # todo: add scale support
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+        #     else:
+        #         attn_output = xe_addons.sdp(query, key, value, mask)
+        # else:
+        #     if key.dtype == torch.uint8:
+        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+        #     else:
+        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask)
         elif seq_length != kv_length and seq_length <= 32:
+            # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
         return attn_output
     else:
         mask = mask[..., :seq_length, :kv_length] if mask is not None else None
+        from ipex_llm.transformers.models.utils import repeat_kv
+        if n_heads != n_kv_heads:
+            key = repeat_kv(key, n_heads // n_kv_heads)
+            value = repeat_kv(value, n_heads // n_kv_heads)
         return torch.nn.functional.scaled_dot_product_attention(
             query, key, value, mask, is_causal=is_causal, scale=scale
         )
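
For reference, the fallback branch at the end of this hunk (plain PyTorch SDPA after expanding grouped KV heads) can be sketched as a standalone CPU function. This is an illustration only, not code from the commit: the function name sdpa_fallback is made up, repeat_interleave stands in for the repeat_kv helper imported in the diff, and the scale= keyword of torch.nn.functional.scaled_dot_product_attention requires PyTorch 2.1 or newer.

import torch

def sdpa_fallback(query, key, value, mask=None, is_causal=False, scale=None):
    # query: (bsz, n_heads, seq_length, head_dim)
    # key / value: (bsz, n_kv_heads, kv_length, head_dim)
    n_heads, seq_length = query.shape[1], query.shape[2]
    n_kv_heads, kv_length = key.shape[1], key.shape[2]
    # trim a pre-built mask to the current query/key lengths, as in the diff
    mask = mask[..., :seq_length, :kv_length] if mask is not None else None
    # expand grouped KV heads so torch sees matching head counts (GQA/MQA)
    if n_heads != n_kv_heads:
        key = key.repeat_interleave(n_heads // n_kv_heads, dim=1)
        value = value.repeat_interleave(n_heads // n_kv_heads, dim=1)
    # note: torch raises if both an explicit mask and is_causal=True are passed
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, mask, is_causal=is_causal, scale=scale
    )

q = torch.randn(1, 32, 16, 128)
k = torch.randn(1, 8, 16, 128)
v = torch.randn(1, 8, 16, 128)
print(sdpa_fallback(q, k, v, is_causal=True).shape)  # torch.Size([1, 32, 16, 128])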

@@ -59,6 +59,10 @@ def siglip_attention_forward(
         and get_xpu_device_type(query_states) in ["arc", "flex"] and
         query_states.dtype in [torch.float, torch.half]
     ):
+        n_heads, kv_length = query_states.size(1), key_states.size(2)
+        from ipex_llm.transformers.models.common import prepare_mask
+        attention_mask = prepare_mask(attention_mask, bsz, n_heads, q_len, kv_length,
+                                      False, query_states.dtype, query_states.device)
         import xe_addons
         attn_weights = None
         attn_output = xe_addons.siglip_sdp_non_causal(query_states, key_states,
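
For context, here is a rough CPU reference (with toy sizes) of the non-causal attention that xe_addons.siglip_sdp_non_causal computes on Arc/Flex XPUs. The output layout of prepare_mask is an assumption here, not shown in the diff: it is treated as an additive float mask broadcast to (bsz, n_heads, q_len, kv_length), with -inf at masked positions, which this sketch builds by hand.

import torch

bsz, n_heads, q_len, head_dim = 2, 12, 64, 64
kv_length = q_len  # vision self-attention: keys/values share the query length
query = torch.randn(bsz, n_heads, q_len, head_dim)
key = torch.randn(bsz, n_heads, kv_length, head_dim)
value = torch.randn(bsz, n_heads, kv_length, head_dim)

# e.g. a padding mask: True marks positions that must not be attended to
padding = torch.zeros(bsz, kv_length, dtype=torch.bool)   # nothing padded here
attention_mask = torch.zeros(bsz, 1, q_len, kv_length)
attention_mask.masked_fill_(padding[:, None, None, :], float("-inf"))
attention_mask = attention_mask.expand(bsz, n_heads, q_len, kv_length)

attn_output = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attention_mask, is_causal=False
)
print(attn_output.shape)  # torch.Size([2, 12, 64, 64])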

@@ -388,6 +388,15 @@ def fp16_fusion_check(proj, x, training):
     return True


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
+                                                           n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 def update_past_key_value(past_key_value, key_states, value_states,
                           kv_seq_len, use_quantize_kv, device):
     bsz, num_heads, _, head_dim = key_states.shape
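
The repeat_kv helper added above is the usual grouped-query-attention head expansion. A quick, self-contained check (not part of the commit) that it matches torch.repeat_interleave along the KV-head dimension:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # copy of the helper from the hunk above, so this snippet runs on its own
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads,
                                                           n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 16, 64)       # (batch, n_kv_heads, seq_len, head_dim)
out = repeat_kv(kv, 8 // 4)          # expand 4 KV heads to match 8 query heads
assert out.shape == (2, 8, 16, 64)
assert torch.equal(out, kv.repeat_interleave(2, dim=1))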