use fp16_sdp when head_dim=96 (#10976)
This commit is contained in:
parent
b7f7d05a7e
commit
e753125880
1 changed file with 1 addition and 1 deletion
@@ -362,7 +362,7 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif head_dim != 128 and head_dim != 64:
+    elif head_dim not in [64, 96, 128]:
         # esimd_sdp only support head_dim = 128 and 64 now
         return False
     elif q_len == k_len:
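This one-line change widens the head-dimension gate in use_new_esimd_sdp_fp16 so that queries with head_dim = 96 can also take the ESIMD FP16 fused SDP path, in addition to the previously accepted 64 and 128; the neighbouring context comment about supported head dimensions is untouched by the patch. Below is a minimal standalone sketch of the gating logic after this change. The function name and the simplified signature are hypothetical; the real check also looks at q_len, k_len, and other conditions outside this hunk.

    import torch

    def _sdp_fp16_head_dim_ok(head_dim, query_states):
        # esimd_sdp only has an optimized FP16 path
        if query_states.dtype != torch.float16:
            return False
        # after this patch, head_dim 96 is accepted alongside 64 and 128
        if head_dim not in [64, 96, 128]:
            return False
        return True

    # usage sketch with hypothetical shapes: (batch, num_heads, seq_len, head_dim)
    q = torch.randn(1, 32, 1, 96, dtype=torch.float16)
    assert _sdp_fp16_head_dim_ok(q.shape[-1], q)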