LLM: update esimd sdp kernel (#9871)
parent 023679459e
commit 3e05c9e11b
1 changed file with 2 additions and 2 deletions
@@ -276,8 +276,8 @@ def llama_attention_forward_4_31(
     elif use_esimd_sdp(q_len, self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states.contiguous(),
-                                                    value_states.contiguous())
+                                                    key_states,
+                                                    value_states)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
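For context, a minimal sketch of the call site after this change, assuming only the names that appear in the hunk (use_esimd_sdp, linear_fp16_esimd.sdp_forward) and assuming the updated ESIMD SDP kernel now accepts strided key/value tensors, which is why the explicit .contiguous() copies are dropped. The fallback branch below is a stand-in for the original else: path, not the upstream implementation.

import torch

def esimd_sdp_attention(query_states, key_states, value_states, q_len, head_dim):
    # Hedged sketch of the dispatch around this hunk. `use_esimd_sdp` and
    # `linear_fp16_esimd` are taken from the diff and assumed to be available
    # in an Intel XPU environment with the ESIMD extension installed.
    if use_esimd_sdp(q_len, head_dim, query_states):
        import linear_fp16_esimd
        # After this commit the key/value tensors are passed as-is; the updated
        # kernel is assumed to handle non-contiguous layouts internally.
        attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                    key_states,
                                                    value_states)
        return attn_output.view(query_states.shape), None  # attn_weights = None
    # Fallback (assumption): plain scaled-dot-product attention in PyTorch,
    # standing in for the original else: branch.
    attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / (head_dim ** 0.5)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value_states), attn_weights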