diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 41d11815..4e879603 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -46,7 +46,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ apply_rotary_pos_emb, is_enough_kv_cache_room_4_36 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu -from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp +from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \ + use_sdp_fp8 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check from ipex_llm.transformers.models.utils import use_decoding_fast_path from transformers.modeling_outputs import BaseModelOutputWithPast @@ -449,7 +450,7 @@ def llama_attention_forward_4_31_quantized( kv_seq_len = key_states.shape[-2] past_key_value = (key_states, value_states) - if query_states.size(2) != 1 or query_states.device.type != 'xpu': + if not use_sdp_fp8(q_len, key_states.shape[2], query_states): key_states, value_states = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) # repeat k/v heads if n_kv_heads < n_heads @@ -666,7 +667,7 @@ def llama_attention_forward_4_31_original( is_causal=True) attn_weights = None elif not self.training and not hidden_states.requires_grad and \ - use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask): + use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states): import linear_q4_0 attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask) attn_output = attn_output.view(query_states.shape) @@ -1074,7 +1075,7 @@ def llama_attention_forward_4_36_quantized( self.layer_idx, cache_kwargs, new_layout=True) kv_seq_len = key_states.shape[-2] - if query_states.size(2) != 1 or query_states.device.type != 'xpu': + if not use_sdp_fp8(q_len, key_states.shape[2], query_states): key_states, value_states = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) key_states = repeat_kv(key_states, self.num_key_value_groups)\ @@ -1342,7 +1343,7 @@ def llama_attention_forward_4_36_original( is_causal=True) attn_weights = None elif not self.training and not hidden_states.requires_grad and \ - use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states): import linear_q4_0 attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask) attn_output = attn_output.view(query_states.shape) diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 75749130..fd7509db 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -52,7 +52,8 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ is_enough_kv_cache_room_4_36 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS -from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp +from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \ + use_sdp_fp8 from ipex_llm.transformers.models.utils import use_decoding_fast_path from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv @@ -310,7 +311,7 @@ def mistral_attention_forward_quantized( kv_seq_len = key_states.shape[-2] past_key_value = (key_states, value_states) - if query_states.size(2) != 1 or query_states.device.type != 'xpu': + if not use_sdp_fp8(q_len, key_states.shape[2], query_states): key_states, value_states = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) @@ -503,7 +504,7 @@ def mistral_attention_forward_original( attn_weights = None attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states): # new fp16 sdp doesn't require repeat_kv import linear_q4_0 attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask) @@ -687,7 +688,7 @@ def mistral_attention_forward_4_36_quantized( self.layer_idx, cache_kwargs, new_layout=True) kv_seq_len = key_states.shape[-2] - if query_states.size(2) != 1 or query_states.device.type != 'xpu': + if not use_sdp_fp8(q_len, key_states.shape[2], query_states): key_states, value_states = restore_fp8_kv_cache(key_states, value_states, query_states.dtype) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) @@ -896,7 +897,7 @@ def mistral_attention_forward_4_36_original( attn_weights = None attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states): # new fp16 sdp doesn't require repeat_kv import linear_q4_0 attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask) diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 2be095a0..b5108ae4 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -325,7 +325,7 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None): # esimd_sdp only support head_dim = 128 now return False elif q_len != 1: - # esimd_sdp only support rest token now + # esimd_sdp only support rest token and q_len == 1 now return False elif k_len < 8: # esimd_sdp will cause wrong output when k_len < 8 @@ -363,6 +363,42 @@ def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None): return True +def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states): + if query_states.device.type != "xpu": + # esimd_sdp only support GPU now + return False + elif query_states.dtype != torch.float16: + # esimd_sdp only has optimization for FP16 now + return False + elif head_dim != 128 and head_dim != 64: + # esimd_sdp only support head_dim = 128 and 64 now + return False + elif q_len == k_len: + # new sdp_fp16 only support rest token now + return False + elif q_len > 32: + # Use new sdp_fp16 only when q_len <= 32 + return False + + device_name = torch.xpu.get_device_name(query_states.device.index) + if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \ + and is_deepspeed_available: + # It seems there is an issue in DeepSpeed AutoTP when multi-card inference, + # Disable new sdp_fp16 for now + return False + + return True + + +def use_sdp_fp8(q_len, k_len, query_states): + if query_states.device.type != "xpu": + return False + if q_len == k_len: + # sdp_fp8 only support rest token now + return False + return True + + def mlp_fusion_check(x, qtype, training): invalidInputError(x.dim() == 2, "Here input x's dim should be 2.")