commit 7b1d9ad7c0 (parent 3e601f9a5d)
2 changed files with 8 additions and 5 deletions

@@ -284,7 +284,7 @@ def llama_attention_forward_4_31(
                                                           value_states,
                                                           is_causal=True)
             attn_weights = None
-        elif use_esimd_sdp(q_len, self.head_dim, query_states):
+        elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             import linear_fp16_esimd
             attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                         key_states,
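
In this hunk (and in the 4.36 hunk below) the call site now forwards `key_states.shape[2]` as the new `k_len` argument. At this point in the forward pass `key_states` is expected to be laid out as `(batch, num_heads, kv_seq_len, head_dim)`, so dimension 2 is the current KV-cache length. A minimal sketch with made-up decode-step shapes (illustrative only, not code from this repository):

import torch

# Hypothetical decode step: batch=1, 32 heads, head_dim=128,
# one new query token attending over 5 cached key/value positions.
query_states = torch.randn(1, 32, 1, 128)
key_states = torch.randn(1, 32, 5, 128)

q_len = query_states.shape[2]   # 1 -> the single "rest token" case
k_len = key_states.shape[2]     # 5 -> below the new minimum of 8, so the
                                #      esimd path would now be skipped
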
@@ -689,11 +689,11 @@ def llama_attention_forward_4_36(
                                                           value_states,
                                                           is_causal=True)
             attn_weights = None
-        elif use_esimd_sdp(q_len, self.head_dim, query_states):
+        elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
             import linear_fp16_esimd
             attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                        key_states.contiguous(),
-                                                        value_states.contiguous())
+                                                        key_states,
+                                                        value_states)
             attn_output = attn_output.view(query_states.shape)
             attn_weights = None
         else:
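
The 4.36 path also stops calling `.contiguous()` on the key/value tensors before handing them to `sdp_forward`. As a reminder of what that call does in plain PyTorch (this is generic PyTorch behaviour; whether the esimd kernel accepts strided inputs is an assumption the diff itself does not show):

import torch

# A transposed view: only the strides change, no data is moved.
x = torch.randn(1, 5, 32, 128).transpose(1, 2)    # -> shape (1, 32, 5, 128)
print(x.is_contiguous())           # False
y = x.contiguous()                 # materialises a contiguous copy
print(y.is_contiguous(), y.shape)  # True torch.Size([1, 32, 5, 128])
# Dropping .contiguous() skips forcing such a copy when the key/value
# tensors passed to linear_fp16_esimd.sdp_forward are not already contiguous.
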
@@ -230,13 +230,16 @@ def use_flash_attention(query, key):
         return True


-def use_esimd_sdp(q_len, head_dim, query_states):
+def use_esimd_sdp(q_len, k_len, head_dim, query_states):
     if head_dim != 128:
         # esimd_sdp only support head_dim = 128 now
         return False
     elif q_len != 1:
         # esimd_sdp only support rest token now
         return False
+    elif k_len < 8:
+        # esimd_sdp will cause wrong output when k_len < 8
+        return False
     elif query_states.device.type != "xpu":
         # esimd_sdp only support GPU now
         return False
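
Taken together, the updated gate can be exercised on its own. The sketch below is a standalone rewrite of the checks visible in this hunk, with a hypothetical trailing return True; the real function may contain further checks outside the hunk:

import torch

def esimd_sdp_gate(q_len, k_len, head_dim, query_states):
    # Mirrors the checks shown above: 128-dim heads, a single "rest" token,
    # at least 8 cached positions, and a tensor living on an XPU device.
    if head_dim != 128:
        return False
    elif q_len != 1:
        return False
    elif k_len < 8:
        return False
    elif query_states.device.type != "xpu":
        return False
    return True  # assumption: any remaining checks outside the hunk pass

q = torch.randn(1, 32, 1, 128)        # CPU tensor, for illustration only
print(esimd_sdp_gate(1, 5, 128, q))   # False: k_len < 8
print(esimd_sdp_gate(1, 16, 128, q))  # False: device is CPU, not XPU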