support passing attn_scale to sdpa (#12619)
parent 40a7d2b4f0
commit a9abde0b5d
1 changed file with 5 additions and 22 deletions
@@ -237,40 +237,23 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)
 
         # compute
-        # import xe_addons
-        # if is_causal:
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
-        #     else:
-        #         attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
-        # elif seq_length != kv_length and seq_length <= 32:
-        #     # todo: add scale support
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask)
-        #     else:
-        #         attn_output = xe_addons.sdp(query, key, value, mask)
-        # else:
-        #     if key.dtype == torch.uint8:
-        #         attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
-        #     else:
-        #         attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
-
         import xe_addons
         if is_causal:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
             else:
-                attn_output = xe_addons.sdp_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
         elif seq_length != kv_length and seq_length <= 32:
+            # todo: add scale support
             if key.dtype == torch.uint8:
                 attn_output = xe_addons.sdp_fp8(query, key, value, mask)
             else:
                 attn_output = xe_addons.sdp(query, key, value, mask)
         else:
             if key.dtype == torch.uint8:
-                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
             else:
-                attn_output = xe_addons.sdp_non_causal(query, key, value, mask)
+                attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)
 
         return attn_output
     else:
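For context, the scale argument threaded through these calls is the multiplier applied to the query-key scores before the softmax. Below is a minimal pure-PyTorch sketch of that semantics; the helper name and tensor shapes are illustrative only, not the fused xe_addons kernels, whose signatures are taken as given by this diff.

import math
import torch

def sdpa_reference(query, key, value, mask=None, scale=None):
    # Defaults to 1/sqrt(head_dim), the usual attention scaling,
    # when no explicit attn_scale is passed.
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scale
    if mask is not None:
        # additive mask, e.g. -inf on masked positions
        attn_weights = attn_weights + mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value)

# usage: [batch, heads, seq, head_dim] tensors with a custom attn_scale
q = torch.randn(1, 2, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
out = sdpa_reference(q, k, v, scale=0.1)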