update sdp support (#12847)
This commit is contained in:
parent
93c10be762
commit
aee2db30f9
1 changed files with 1 additions and 1 deletions
|
|
@ -230,7 +230,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
|
||||||
if (
|
if (
|
||||||
device.type == "xpu"
|
device.type == "xpu"
|
||||||
and dtype in [torch.float, torch.half]
|
and dtype in [torch.float, torch.half]
|
||||||
and head_dim in [64, 80, 96, 128]
|
and head_dim in [64, 80, 96, 128, 192, 256]
|
||||||
):
|
):
|
||||||
# prepare scale
|
# prepare scale
|
||||||
scale = 1 / math.sqrt(head_dim) if scale is None else scale
|
scale = 1 / math.sqrt(head_dim) if scale is None else scale
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue