diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index ae9b7812..2d5597c2 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -46,7 +46,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8, \
+    use_sdp_causal
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -1678,8 +1679,17 @@ def llama_attention_forward_4_38_quantized(
         if len(past_key_value.key_cache) <= self.layer_idx:
             repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
             repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-            if should_split_qkv_tensor(query_states, bsz, self.num_heads,
-                                       q_len, kv_seq_len, output_attentions):
+            if use_cache:
+                cache_kwargs = None
+                key_states, value_states = past_key_value.update(key_states, value_states,
+                                                                 self.layer_idx, cache_kwargs)
+            if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
+                                            query_states, self.training):
+                import xe_addons
+                attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
+                                                       value_states, attention_mask)
+            elif should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                         q_len, kv_seq_len, output_attentions):
                 attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                              repeated_value_states,
                                                              attention_mask, cache_position,
@@ -1719,10 +1729,6 @@ def llama_attention_forward_4_38_quantized(
                 attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                                      dtype=torch.float32).to(query_states.dtype)
                 attn_output = torch.matmul(attn_weights, repeated_value_states)
-            if use_cache:
-                cache_kwargs = None
-                key_states, value_states = past_key_value.update(key_states, value_states,
-                                                                 self.layer_idx, cache_kwargs)
             else:
                 cache_kwargs = None  # Specific to RoPE models
                 key_states, value_states = past_key_value.update(key_states, value_states,
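
Note on the change: during prefill the FP8 KV cache is now updated before attention is computed, and when use_sdp_causal approves the shapes, dtype, and device, Q, K, V are handed to the fused xe_addons.sdp_fp8_causal kernel instead of the explicit QK^T / softmax / matmul sequence, which remains as the fallback. The snippet below is only an illustrative sketch of what a fused causal SDP call computes relative to that explicit path, using stock PyTorch scaled_dot_product_attention in fp32 on CPU; it does not use xe_addons, FP8 quantization, or any ipex-llm helper.

# --- Illustrative sketch (not part of the patch above) ---
import math

import torch
import torch.nn.functional as F

# Toy fp32 CPU tensors standing in for a prefill step: batch 1, 4 heads,
# query length 16, head_dim 32.
bsz, n_heads, q_len, head_dim = 1, 4, 16, 32
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# Explicit path, mirroring the fallback branch kept by the diff:
# scores = QK^T / sqrt(d), additive causal mask, softmax in fp32, matmul with V.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
causal_mask = torch.full((q_len, q_len), float("-inf")).triu(diagonal=1)
attn_weights = F.softmax(scores + causal_mask, dim=-1, dtype=torch.float32).to(q.dtype)
ref_out = torch.matmul(attn_weights, v)

# Fused path: a single causal SDP call, analogous in spirit to the kernel the
# diff dispatches to (minus the FP8 cache handling and XPU specifics).
fused_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(ref_out, fused_out, atol=1e-5))  # expected: True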