diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 2db6bd1b..21412806 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -362,7 +362,7 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif head_dim != 128 and head_dim != 64:
-        # esimd_sdp only support head_dim = 128 and 64 now
+    elif head_dim not in [64, 96, 128]:
+        # esimd_sdp only supports head_dim = 64, 96 and 128 now
         return False
     elif q_len == k_len:
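
For context, a minimal sketch of the behavior this hunk changes. The helper name `_esimd_sdp_head_dim_ok` is hypothetical (the real guard lives inside `use_new_esimd_sdp_fp16`, shown above); it condenses the two checks visible in the hunk: the FP16-only restriction and the widened head-dim membership test that now admits 96 alongside 64 and 128.

```python
import torch

def _esimd_sdp_head_dim_ok(head_dim: int, dtype: torch.dtype) -> bool:
    # ESIMD SDP only has an FP16 optimization, per the hunk's first check.
    if dtype != torch.float16:
        return False
    # Patched guard: membership check instead of two inequality tests,
    # with 96 added to the supported head dimensions.
    return head_dim in (64, 96, 128)

# A 96-dim FP16 head now passes, where the old check rejected it:
assert _esimd_sdp_head_dim_ok(96, torch.float16)
assert not _esimd_sdp_head_dim_ok(96, torch.float32)
```

Using `not in` over a single list also makes future additions to the supported head dimensions a one-token change rather than another chained inequality.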