support passing attn_scale to sdpa (#12619)
This commit is contained in:
parent
40a7d2b4f0
commit
a9abde0b5d
1 changed file with 5 additions and 22 deletions
|
|
@@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
|
|||
mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
|
||||
|
||||
# compute
|
||||
# import xe_addons
|
||||
# if is_causal:
|
||||
# if key.dtype == torch.uint8:
|
||||
# attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
|
||||
# else:
|
||||
# attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
|
||||
# elif seq_length != kv_length and seq_length <= 32:
|
||||
# # todo: add scale support
|
||||
# if key.dtype == torch.uint8:
|
||||
# attn_output = xe_addons.sdp_fp8(query, key, value, mask)
|
||||
# else:
|
||||
# attn_output = xe_addons.sdp(query, key, value, mask)
|
||||
# else:
|
||||
# if key.dtype == torch.uint8:
|
||||
# attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
|
||||
# else:
|
||||
# attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
|
||||
|
||||
import xe_addons
|
||||
if is_causal:
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_causal(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
|
||||
elif seq_length != kv_length and seq_length <= 32:
|
||||
# todo: add scale support
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
|
||||
else:
|
||||
attn_output = xe_addons.sdp(query, key, value, mask)
|
||||
else:
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
|
||||
else:
|
||||
attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
|
||||
|
||||
return attn_output
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue