update sdp support (#12847)
This commit is contained in:
parent
93c10be762
commit
aee2db30f9
1 changed files with 1 additions and 1 deletions
|
|
@ -230,7 +230,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
|
|||
if (
|
||||
device.type == "xpu"
|
||||
and dtype in [torch.float, torch.half]
|
||||
and head_dim in [64, 80, 96, 128]
|
||||
and head_dim in [64, 80, 96, 128, 192, 256]
|
||||
):
|
||||
# prepare scale
|
||||
scale = 1 / math.sqrt(head_dim) if scale is None else scale
|
||||
|
|
|
|||
Loading…
Reference in a new issue