From e7531258804a72b6dbba9502f9b0dfd1fc31a627 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 9 May 2024 17:02:59 +0800
Subject: [PATCH] use fp16_sdp when head_dim=96 (#10976)

---
 python/llm/src/ipex_llm/transformers/models/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 2db6bd1b..21412806 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -362,7 +362,7 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
     elif query_states.dtype != torch.float16:
         # esimd_sdp only has optimization for FP16 now
         return False
-    elif head_dim != 128 and head_dim != 64:
+    elif head_dim not in [64, 96, 128]:
         # esimd_sdp only support head_dim = 128 and 64 now
         return False
     elif q_len == k_len:
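
For context, here is a minimal sketch of how the revised gate behaves after this patch. It is not the actual ipex-llm implementation: the surrounding dtype, q_len, and device checks from use_new_esimd_sdp_fp16 are omitted, and the helper name below is hypothetical.

```python
# Hypothetical standalone illustration of the head_dim check changed in this patch.
# Before the patch only 64 and 128 were accepted; 96 is now accepted as well.
def head_dim_supported(head_dim: int) -> bool:
    return head_dim in [64, 96, 128]

for hd in (64, 80, 96, 128):
    # 80 is still rejected; 96 now passes the fp16 esimd_sdp gate
    print(hd, head_dim_supported(hd))
```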