support passing attn_scale to sdpa (#12619)

Yishuo Wang authored on 2024-12-26 16:58:09 +08:00, committed by GitHub
parent 40a7d2b4f0
commit a9abde0b5d

@@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
 
         # compute
-        # import xe_addons
-        # if is_causal:
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
-        #     else:
-        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
-        # elif seq_length != kv_length and seq_length <= 32:
-        #     # todo: add scale support
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
-        #     else:
-        #         attn_output = xe_addons.sdp(query, key, value, mask)
-        # else:
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
-        #     else:
-        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
         elif seq_length != kv_length and seq_length <= 32:
             # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
         return attn_output
     else:
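
For readers without an Intel GPU at hand, the following is a minimal pure-PyTorch sketch of the semantics the forwarded scale argument carries; it is not the xe_addons kernel, and it assumes the usual convention that the scale falls back to 1/sqrt(head_dim) when no attn_scale is passed. The sdpa_reference name and the tensor shapes below are illustrative only. Note that the short-sequence branch in the diff (seq_length <= 32) still carries a "todo: add scale support" comment, so a custom scale only takes effect in the causal and non-causal paths.

    import math
    import torch

    def sdpa_reference(query, key, value, mask=None, scale=None):
        # Reference semantics: attn = softmax(Q @ K^T * scale + mask) @ V.
        # When no attn_scale is given, fall back to the conventional
        # 1/sqrt(head_dim) scaling (assumption, matching standard SDPA).
        if scale is None:
            scale = 1.0 / math.sqrt(query.size(-1))
        attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scale
        if mask is not None:
            # additive mask, e.g. a causal or padding mask of -inf/0 values
            attn_weights = attn_weights + mask
        attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32)
        return torch.matmul(attn_weights.to(query.dtype), value)

    # A model that needs a non-default attention scale can now pass it
    # straight through instead of pre-scaling the query tensor:
    q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
    out = sdpa_reference(q, k, v, scale=0.05)
    print(out.shape)  # torch.Size([1, 8, 16, 64])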