diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py
index 8e54ee55..26509b92 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -27,7 +27,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
@@ -276,11 +276,9 @@ def baichuan_attention_forward_7b_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/cohere.py b/python/llm/src/ipex_llm/transformers/models/cohere.py
index f5404665..1ff8b53b 100644
--- a/python/llm/src/ipex_llm/transformers/models/cohere.py
+++ b/python/llm/src/ipex_llm/transformers/models/cohere.py
@@ -50,7 +50,7 @@ from ipex_llm.transformers.models.utils import extend_kv_cache, append_kv_cache
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
@@ -420,9 +420,13 @@ def cohere_attention_forward_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        if attention_mask is not None:
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        else:
+            causal_mask = None
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, causal_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index e3e06fe9..45c0c4f4 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -46,8 +46,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -673,9 +672,9 @@ def llama_attention_forward_4_31_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
@@ -1348,9 +1347,9 @@ def llama_attention_forward_4_36_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index ef1971f2..fcbb6961 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -52,8 +52,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
@@ -591,10 +590,10 @@ def mistral_attention_forward_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1032,10 +1031,10 @@ def mistral_attention_forward_4_36_original(
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    elif use_new_esimd_sdp_fp16(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         # new fp16 sdp doesn't require repeat_kv
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index d19f46bb..c2013540 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -55,7 +55,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS
 
@@ -332,12 +332,9 @@ def mixtral_attention_forward(
                                                      value_states,
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(query_states.shape[2], key_states.shape[2],
-                       self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+    elif use_sdp(query_states.shape[2], key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index dbcce5a4..dba3a4e4 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -40,9 +40,8 @@ from ipex_llm.transformers.models.utils import (
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_new_esimd_sdp_fp16, use_quantize_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_causal
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
 from typing import Optional, Tuple, List
@@ -142,9 +141,9 @@ def attention_forward(
         import linear_q4_0
         attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states, attention_mask)
     elif (isinstance(past_key_value, DynamicNormalCache) and
-          use_new_esimd_sdp_fp16(q_len, kv_seq_len, self.head_dim, query_states)):
+          use_sdp(q_len, kv_seq_len, self.head_dim, query_states)):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query_states, key_states, value_states, attention_mask)
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
     else:
         if isinstance(past_key_value, DynamicFp8Cache):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index 7405d552..ffc95552 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -42,8 +42,7 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
-    use_sdp_fp8
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -291,9 +290,9 @@ def qwen_attention_forward_original(
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
+            use_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_q4_0
-        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
+        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 13523092..d599152e 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.kv import DynamicFp8Cache
 from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
@@ -565,11 +565,9 @@ def qwen2_attention_forward_origin(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
index 94b57297..24ac0767 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
@@ -32,7 +32,7 @@ import torch.utils.checkpoint
 from transformers.utils import logging
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 
 import os
@@ -207,11 +207,9 @@ def qwen_attention_forward_vl(
     query = query.permute(0, 2, 1, 3)
 
     if not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query,
-                                                    key,
-                                                    value)
+            use_sdp(q_len, key.shape[2], self.head_dim, query):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index b91b40a9..4e8b685f 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -53,7 +53,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, \
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.mistral import should_use_fuse_rope, repeat_kv
 try:
     from transformers.cache_utils import Cache
@@ -266,11 +266,9 @@ def stablelm_attention_forward_original(
                                                      is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states, attention_mask):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index c7a29bc8..2ca74fa9 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -375,20 +375,31 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
         return True
 
 
-def use_sdp_fp8(q_len, k_len, query_states):
-    if query_states.device.type != "xpu":
-        return False
-    if q_len == k_len:
-        # sdp_fp8 only support rest token now
-        return False
-    return True
+def use_sdp(q_len, kv_len, head_dim, query_states):
+    return (
+        query_states.device.type == "xpu"
+        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and head_dim in [64, 96, 128]
+        and q_len != kv_len  # next token
+        and q_len <= 32  # lookup
+    )
+
+
+def use_sdp_fp8(q_len, kv_len, query_states):
+    return (
+        query_states.device.type == "xpu"
+        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and q_len != kv_len  # next token
+        and q_len <= 32  # lookup
+    )
 
 
 def use_sdp_causal(q_len, kv_len, query_states, training):
     return (
         q_len == kv_len  # first token
         and query_states.device.type == "xpu"  # GPU
-        and not query_states.requires_grad and not training  # no training
+        and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
+        and not query_states.requires_grad and not training  # not training
     )
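
Note (not part of the patch): the sketch below restates the unified `use_sdp` gate added in utils.py and shows how a caller could dispatch between the fused `linear_q4_0.sdp` kernel and a generic fallback. The `sdp_attention` helper and the `F.scaled_dot_product_attention` fallback are illustrative assumptions for this example only; `linear_q4_0` is the native extension shipped with IPEX-LLM's XPU builds and is importable only in that environment.

import torch
import torch.nn.functional as F


def use_sdp(q_len, kv_len, head_dim, query_states):
    # Same conditions as the gate introduced in this patch: XPU tensor, fp32/fp16 dtype,
    # supported head size, and a short query (decode / lookup) against a longer KV cache.
    return (
        query_states.device.type == "xpu"
        and query_states.dtype in [torch.float, torch.half]
        and head_dim in [64, 96, 128]
        and q_len != kv_len
        and q_len <= 32
    )


def sdp_attention(query, key, value, attention_mask=None):
    # query/key/value are laid out as [batch, num_heads, seq_len, head_dim],
    # matching the call sites in the model files above.
    q_len, kv_len, head_dim = query.shape[2], key.shape[2], query.shape[3]
    if use_sdp(q_len, kv_len, head_dim, query):
        import linear_q4_0  # IPEX-LLM native XPU ops; only available on XPU installs
        attn_output = linear_q4_0.sdp(query, key, value, attention_mask)
        return attn_output.view(query.shape)
    # Generic fallback for CPU tensors, prefill, or unsupported head sizes (illustrative).
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)


if __name__ == "__main__":
    q = torch.randn(1, 32, 1, 128)    # single decode token
    k = torch.randn(1, 32, 256, 128)  # cached keys
    v = torch.randn(1, 32, 256, 128)  # cached values
    out = sdp_attention(q, k, v)      # CPU tensors, so this takes the fallback path
    print(out.shape)                  # torch.Size([1, 32, 1, 128])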