update sdp support (#12847)

This commit is contained in:
Yishuo Wang 2025-02-19 12:07:00 +08:00 committed by GitHub
parent 93c10be762
commit aee2db30f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -230,7 +230,7 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
if (
device.type == "xpu"
and dtype in [torch.float, torch.half]
and head_dim in [64, 80, 96, 128]
and head_dim in [64, 80, 96, 128, 192, 256]
):
# prepare scale
scale = 1 / math.sqrt(head_dim) if scale is None else scale