parent
3e601f9a5d
commit
7b1d9ad7c0
2 changed files with 8 additions and 5 deletions
|
|
@ -284,7 +284,7 @@ def llama_attention_forward_4_31(
|
||||||
value_states,
|
value_states,
|
||||||
is_causal=True)
|
is_causal=True)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
elif use_esimd_sdp(q_len, self.head_dim, query_states):
|
elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
import linear_fp16_esimd
|
import linear_fp16_esimd
|
||||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||||
key_states,
|
key_states,
|
||||||
|
|
@ -689,11 +689,11 @@ def llama_attention_forward_4_36(
|
||||||
value_states,
|
value_states,
|
||||||
is_causal=True)
|
is_causal=True)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
elif use_esimd_sdp(q_len, self.head_dim, query_states):
|
elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
|
||||||
import linear_fp16_esimd
|
import linear_fp16_esimd
|
||||||
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
attn_output = linear_fp16_esimd.sdp_forward(query_states,
|
||||||
key_states.contiguous(),
|
key_states,
|
||||||
value_states.contiguous())
|
value_states)
|
||||||
attn_output = attn_output.view(query_states.shape)
|
attn_output = attn_output.view(query_states.shape)
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -230,13 +230,16 @@ def use_flash_attention(query, key):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def use_esimd_sdp(q_len, head_dim, query_states):
|
def use_esimd_sdp(q_len, k_len, head_dim, query_states):
|
||||||
if head_dim != 128:
|
if head_dim != 128:
|
||||||
# esimd_sdp only support head_dim = 128 now
|
# esimd_sdp only support head_dim = 128 now
|
||||||
return False
|
return False
|
||||||
elif q_len != 1:
|
elif q_len != 1:
|
||||||
# esimd_sdp only support rest token now
|
# esimd_sdp only support rest token now
|
||||||
return False
|
return False
|
||||||
|
elif k_len < 8:
|
||||||
|
# esimd_sdp will cause wrong output when k_len < 8
|
||||||
|
return False
|
||||||
elif query_states.device.type != "xpu":
|
elif query_states.device.type != "xpu":
|
||||||
# esimd_sdp only support GPU now
|
# esimd_sdp only support GPU now
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue