use fp16_sdp when head_dim=96 (#10976)

Yishuo Wang 2024-05-09 17:02:59 +08:00 committed by GitHub
parent b7f7d05a7e
commit e753125880


@@ -362,7 +362,7 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif head_dim != 128 and head_dim != 64:
+    elif head_dim not in [64, 96, 128]:
         # esimd_sdp only support head_dim = 128 and 64 now
         return False
     elif q_len == k_len:
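
For reference, a minimal standalone sketch of the head_dim gate after this change. Only the two branches shown in the hunk come from the diff; the helper name, the example tensor shape, and the omission of the surrounding q_len/k_len branches of use_new_esimd_sdp_fp16 are illustrative assumptions.

import torch

def esimd_sdp_fp16_head_dim_ok(head_dim, query_states):
    # Illustrative gate only; the real use_new_esimd_sdp_fp16 also checks
    # q_len, k_len and other conditions that are not part of this hunk.
    if query_states.dtype != torch.float16:
        # esimd_sdp only has an optimization for FP16 now
        return False
    if head_dim not in [64, 96, 128]:
        # after this commit, head_dim = 64, 96, and 128 pass the check
        return False
    return True

# head_dim = 96 now passes the gate for an FP16 query tensor.
q = torch.empty(1, 32, 1, 96, dtype=torch.float16)
print(esimd_sdp_fp16_head_dim_ok(96, q))   # True
print(esimd_sdp_fp16_head_dim_ok(80, q))   # False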