add sdp_causal for mistral 4.36 (#11686)
* add sdp_causal for mistral
* fix
* update

parent 45c730ff39
commit 736a7ef72e

2 changed files with 31 additions and 7 deletions
ipex_llm/transformers/models/llama.py

@@ -1143,7 +1143,16 @@ def llama_attention_forward_4_41_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
-                                   q_len, kv_seq_len, output_attentions):
+        if use_cache:
+            cache_kwargs = None
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
+        if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                                        query_states, self.training):
+            import xe_addons
+            attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                   value_states, attention_mask)
+        elif should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                     q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states,

@@ -1184,10 +1193,6 @@ def llama_attention_forward_4_41_quantized(
             attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                  dtype=torch.float32).to(query_states.dtype)
             attn_output = torch.matmul(attn_weights, repeated_value_states)
-        if use_cache:
-            cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
     else:
         cache_kwargs = None  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states,
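
Taken together, the two hunks above move the KV-cache update ahead of the kernel dispatch and add a fused causal path: once the new key/value states are written with past_key_value.update, the branch calls xe_addons.sdp_fp8_causal when use_sdp_causal allows it, and only otherwise falls back to the split-qkv native SDP. For reference, the fused call appears to compute ordinary causal scaled-dot-product attention over the grouped KV heads; below is a minimal PyTorch sketch of the expected numerics, assuming the kernel takes un-repeated KV heads and that the branch is only taken when q_len equals the cached length (prefill). The fp8 variant additionally dequantizes the cache on the fly, so this fp16/fp32 reference matches it only up to quantization error.

import torch
import torch.nn.functional as F

def sdp_causal_reference(query_states: torch.Tensor,
                         key_states: torch.Tensor,
                         value_states: torch.Tensor) -> torch.Tensor:
    # query_states: (bsz, num_heads, q_len, head_dim)
    # key/value_states: (bsz, num_kv_heads, kv_len, head_dim), kv_len == q_len here
    n_rep = query_states.shape[1] // key_states.shape[1]
    # broadcast grouped KV heads up to the query heads (same effect as repeat_kv)
    key_states = key_states.repeat_interleave(n_rep, dim=1)
    value_states = value_states.repeat_interleave(n_rep, dim=1)
    # causal scaled-dot-product attention; no explicit attention mask needed
    return F.scaled_dot_product_attention(query_states, key_states, value_states,
                                          is_causal=True)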
ipex_llm/transformers/models/mistral.py

@@ -52,7 +52,8 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
+    use_sdp_causal
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -599,6 +600,15 @@ def mistral_attention_forward_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    elif use_sdp_causal(q_len, key_states.shape[2], self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
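
The new branch's post-processing mirrors the neighbouring SDP branches: the fused output is viewed back into the query layout, the head and sequence dimensions are swapped, and the result is flattened to the hidden size. A small self-contained check of that shape bookkeeping, with illustrative Mistral-7B-like sizes chosen for the example (not taken from this commit):

import torch

bsz, num_heads, q_len, head_dim = 1, 32, 16, 128    # illustrative sizes only
hidden_size = num_heads * head_dim

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
attn_output = torch.randn_like(query_states)         # stand-in for the fused kernel output

attn_output = attn_output.view(query_states.shape)          # (bsz, num_heads, q_len, head_dim)
attn_output = attn_output.transpose(1, 2).contiguous()      # (bsz, q_len, num_heads, head_dim)
attn_output = attn_output.reshape(bsz, q_len, hidden_size)  # (bsz, q_len, hidden_size)
assert attn_output.shape == (bsz, q_len, hidden_size)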
@@ -1052,6 +1062,15 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    elif use_sdp_causal(q_len, key_states.shape[2], self.head_dim,
+                        query_states, self.training):
+        import xe_addons
+        attn_output = xe_addons.sdp_causal(query_states, key_states.contiguous(),
+                                           value_states.contiguous(), attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import xe_addons
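
The transformers-4.36 hunk is identical to the one above, so both mistral forwards pick up the same causal path. To exercise it end to end, the usual ipex-llm low-bit flow on an Intel GPU should be enough; the prefill pass is where use_sdp_causal can now dispatch to xe_addons.sdp_causal. The sketch below is only an illustration: the checkpoint id, prompt, and generation settings are example choices and are not taken from this commit.

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "mistralai/Mistral-7B-Instruct-v0.2"    # example checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             trust_remote_code=True)
model = model.to("xpu")                              # run on the Intel GPU

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
inputs = tokenizer("What is causal attention?", return_tensors="pt").to("xpu")

with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))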