From c7422712fc0d502fb6d516f64a3b3de025d57312 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Tue, 9 Apr 2024 13:50:33 +0800
Subject: [PATCH] mistral 4.36 use fp16 sdp (#10704)

---
 python/llm/src/ipex_llm/transformers/models/mistral.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index e57ba7f6..4ebe4886 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -896,6 +896,15 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     else:
         attn_output, attn_weights = compute_attn_outputs_weights(query_states,
                                                                   key_states,
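The patch adds a new dispatch branch: when `use_esimd_sdp(...)` reports that the fused fp16 ESIMD kernel is applicable, attention is computed via `linear_fp16_esimd.sdp_forward` and the result is reshaped back to `(bsz, q_len, hidden_size)`; otherwise the existing generic path runs. Below is a minimal sketch of that dispatch pattern, not the ipex-llm implementation: `use_fp16_sdp`, `fused_sdp_forward`, and `attention_output` are hypothetical names, and the fused kernel is emulated with stock PyTorch so the example runs on CPU (the real `linear_fp16_esimd` module is only present in ipex-llm's XPU build).

```python
# Sketch of the "fused fp16 SDP when applicable, generic attention otherwise"
# dispatch pattern added by this patch. All names here are illustrative.
import math
import torch


def use_fp16_sdp(q_len: int, kv_len: int, head_dim: int,
                 query_states: torch.Tensor) -> bool:
    # Hypothetical gate, loosely mirroring use_esimd_sdp's signature:
    # single-token decode, fp16 activations, and a supported head size.
    return (q_len == 1
            and kv_len > 0
            and query_states.dtype == torch.float16
            and head_dim in (64, 96, 128))


def fused_sdp_forward(q, k, v):
    # Stand-in for linear_fp16_esimd.sdp_forward. The real kernel runs fused
    # fp16 attention on the XPU; here we upcast and use PyTorch's built-in
    # scaled_dot_product_attention so the sketch works anywhere.
    return torch.nn.functional.scaled_dot_product_attention(
        q.float(), k.float(), v.float()).to(q.dtype)


def attention_output(query_states, key_states, value_states,
                     bsz, q_len, hidden_size):
    head_dim = query_states.shape[-1]
    if use_fp16_sdp(q_len, key_states.shape[2], head_dim, query_states):
        # Fast path: fused SDP kernel, output already (bsz, heads, q_len, head_dim).
        attn_output = fused_sdp_forward(query_states, key_states, value_states)
        attn_output = attn_output.view(query_states.shape)
    else:
        # Generic fallback: explicit softmax(QK^T / sqrt(d)) V.
        attn_weights = query_states @ key_states.transpose(2, 3) / math.sqrt(head_dim)
        attn_weights = torch.softmax(attn_weights, dim=-1,
                                     dtype=torch.float32).to(query_states.dtype)
        attn_output = attn_weights @ value_states
    # Both branches end with the same reshape back to (bsz, q_len, hidden_size),
    # exactly as in the patched mistral_attention_forward_4_36_original.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output.reshape(bsz, q_len, hidden_size)


if __name__ == "__main__":
    bsz, n_heads, q_len, kv_len, head_dim = 1, 32, 1, 128, 128
    q = torch.randn(bsz, n_heads, q_len, head_dim, dtype=torch.float16)
    k = torch.randn(bsz, n_heads, kv_len, head_dim, dtype=torch.float16)
    v = torch.randn(bsz, n_heads, kv_len, head_dim, dtype=torch.float16)
    out = attention_output(q, k, v, bsz, q_len, n_heads * head_dim)
    print(out.shape)  # torch.Size([1, 1, 4096])
```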