LLM: update esimd sdp kernel (#9871)

This commit is contained in:
Ruonan Wang 2024-01-09 18:10:01 +08:00 committed by GitHub
parent 023679459e
commit 3e05c9e11b

View file

@ -276,8 +276,8 @@ def llama_attention_forward_4_31(
     elif use_esimd_sdp(q_len, self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states.contiguous(),
-                                                    value_states.contiguous())
+                                                    key_states,
+                                                    value_states)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else: