diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 179ec18b..26239eff 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -284,7 +284,7 @@ def llama_attention_forward_4_31(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
@@ -689,11 +689,11 @@ def llama_attention_forward_4_36(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, self.head_dim, query_states):
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states.contiguous(),
-                                                    value_states.contiguous())
+                                                    key_states,
+                                                    value_states)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 620f9102..5abd1345 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -230,13 +230,16 @@ def use_flash_attention(query, key):
     return True
 
 
-def use_esimd_sdp(q_len, head_dim, query_states):
+def use_esimd_sdp(q_len, k_len, head_dim, query_states):
     if head_dim != 128:
         # esimd_sdp only support head_dim = 128 now
         return False
     elif q_len != 1:
         # esimd_sdp only support rest token now
         return False
+    elif k_len < 8:
+        # esimd_sdp will cause wrong output when k_len < 8
+        return False
     elif query_states.device.type != "xpu":
         # esimd_sdp only support GPU now
         return False
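
A minimal sketch of the updated gate, for reference while reading the diff: the call sites in llama.py now pass the key sequence length (key_states.shape[2]) as a new k_len argument, and use_esimd_sdp rejects the ESIMD SDP path whenever k_len < 8, which previously produced wrong output. The helper below mirrors only the checks visible in the hunk; the real function in utils.py contains further checks the hunk does not show, so the trailing return True is an assumption, and the tensor shapes in the usage are purely illustrative.

import torch


def use_esimd_sdp(q_len, k_len, head_dim, query_states):
    # Checks visible in this diff hunk; remaining checks from the real helper
    # are not shown here, so this sketch assumes they would pass.
    if head_dim != 128:
        return False   # only head_dim = 128 is supported
    elif q_len != 1:
        return False   # only single-token (rest) decoding
    elif k_len < 8:
        return False   # new check: wrong output when k_len < 8
    elif query_states.device.type != "xpu":
        return False   # only XPU (GPU) devices
    return True        # assumption: checks elided from the hunk pass


# Usage mirroring the patched call sites: k_len comes from key_states.shape[2].
# Shapes are illustrative: (batch, num_heads, seq_len, head_dim).
query_states = torch.randn(1, 32, 1, 128, dtype=torch.float16)
key_states = torch.randn(1, 32, 4, 128, dtype=torch.float16)  # only 4 cached keys

print(use_esimd_sdp(query_states.shape[2], key_states.shape[2],
                    query_states.shape[-1], query_states))
# -> False: k_len == 4 < 8 (and this example runs on CPU, not XPU)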