LLM: limit esimd sdp usage for k_len < 8 (#9959)

* update

* fix
This commit is contained in:
Ruonan Wang 2024-01-23 09:28:23 +08:00 committed by GitHub
parent 3e601f9a5d
commit 7b1d9ad7c0
2 changed files with 8 additions and 5 deletions

View file

@ -284,7 +284,7 @@ def llama_attention_forward_4_31(
value_states, value_states,
is_causal=True) is_causal=True)
attn_weights = None attn_weights = None
elif use_esimd_sdp(q_len, self.head_dim, query_states): elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_fp16_esimd import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query_states, attn_output = linear_fp16_esimd.sdp_forward(query_states,
key_states, key_states,
@ -689,11 +689,11 @@ def llama_attention_forward_4_36(
value_states, value_states,
is_causal=True) is_causal=True)
attn_weights = None attn_weights = None
elif use_esimd_sdp(q_len, self.head_dim, query_states): elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
import linear_fp16_esimd import linear_fp16_esimd
attn_output = linear_fp16_esimd.sdp_forward(query_states, attn_output = linear_fp16_esimd.sdp_forward(query_states,
key_states.contiguous(), key_states,
value_states.contiguous()) value_states)
attn_output = attn_output.view(query_states.shape) attn_output = attn_output.view(query_states.shape)
attn_weights = None attn_weights = None
else: else:

View file

@ -230,13 +230,16 @@ def use_flash_attention(query, key):
return True return True
def use_esimd_sdp(q_len, head_dim, query_states): def use_esimd_sdp(q_len, k_len, head_dim, query_states):
if head_dim != 128: if head_dim != 128:
# esimd_sdp only support head_dim = 128 now # esimd_sdp only support head_dim = 128 now
return False return False
elif q_len != 1: elif q_len != 1:
# esimd_sdp only support rest token now # esimd_sdp only support rest token now
return False return False
elif k_len < 8:
# esimd_sdp will cause wrong output when k_len < 8
return False
elif query_states.device.type != "xpu": elif query_states.device.type != "xpu":
# esimd_sdp only support GPU now # esimd_sdp only support GPU now
return False return False