diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py
index 585fdb89..c99c51ec 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma.py
@@ -41,7 +41,7 @@ from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_x
 from ipex_llm.transformers.models.utils import mlp_fusion_check, GELU
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36, rotate_half
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 import os
 
 
@@ -77,11 +77,6 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope
 
 
-def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
-    return decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1
-
-
 def gemma_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 6c170bbf..b9308bc6 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -48,6 +48,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, fp16_fusion_check
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaModel
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS, FP4
@@ -362,11 +363,12 @@ def llama_attention_forward_4_31_quantized(
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
-    qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
-                          and enough_kv_room and bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
 
     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
@@ -496,11 +498,12 @@ def llama_attention_forward_4_31_original(
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len)
-    qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
-                          enough_kv_room and bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
 
     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
@@ -728,11 +731,12 @@ def llama_attention_selective_batching_forward_4_31(
     # TODO: decoding fast path
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = past_key_value is not None and is_enough_kv_cache_room_4_31(past_key_value[0])
-    qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
-                          bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
 
     updated_past_key_values = []
     # single batch decoding fast path
@@ -948,11 +952,12 @@ def llama_attention_forward_4_36_quantized(
     device = hidden_states.device
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
-    qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope
-                          and enough_kv_room and bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1144,11 +1149,12 @@ def llama_attention_forward_4_36_original(
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
-    qtype_check = llama_decoding_fast_path_qtype_check(self.q_proj)
     no_tp = not self.config.pretraining_tp > 1
-    decoding_fast_path = (no_tp and qtype_check and use_fuse_rope and
-                          enough_kv_room and bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
 
     # single batch decoding fast path
     # forward_qkv takes will perform QKV projection, rotary position embedding
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index c43de710..75749130 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -53,6 +53,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5, IQ2_XXS
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.llama import llama_decoding_fast_path_qtype_check
 from ipex_llm.transformers.models.llama import should_use_xetla_mm_qkv
 from ipex_llm.transformers.models.llama import fuse_qkv_weight_xetla
@@ -87,12 +88,6 @@ def should_use_fuse_rope(self, hidden_states, position_ids):
     return use_fuse_rope
 
 
-def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
-    return llama_decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1 and \
-        not proj.enable_xetla
-
-
 def compute_attn_outputs_weights(query_states, key_states, value_states, bsz, q_len, kv_seq_len,
                                  num_heads, head_dim, hidden_size, attention_mask):
     attn_weights = torch.matmul(
diff --git a/python/llm/src/ipex_llm/transformers/models/mixtral.py b/python/llm/src/ipex_llm/transformers/models/mixtral.py
index 4f069d90..01c420bd 100644
--- a/python/llm/src/ipex_llm/transformers/models/mixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mixtral.py
@@ -53,7 +53,8 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_cache_freq_xpu, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
+from ipex_llm.transformers.models.mistral import should_use_fuse_rope
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.low_bit_linear import IQ2_XXS
@@ -177,9 +178,8 @@ def mixtral_attention_forward(
                                                 use_fuse_rope,
                                                 enough_kv_room,
                                                 bsz * q_len)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
 
-    if decoding_fast_path and self.q_proj.qtype != IQ2_XXS:
+    if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k = past_key_value.key_cache[self.layer_idx]
         cache_v = past_key_value.value_cache[self.layer_idx]
diff --git a/python/llm/src/ipex_llm/transformers/models/phixtral.py b/python/llm/src/ipex_llm/transformers/models/phixtral.py
index b79c37f4..d1029985 100644
--- a/python/llm/src/ipex_llm/transformers/models/phixtral.py
+++ b/python/llm/src/ipex_llm/transformers/models/phixtral.py
@@ -48,7 +48,7 @@ from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
+from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_flash_attention
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index a02db2bd..28cb5efe 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -43,7 +43,7 @@ from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
-from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
 from transformers.modeling_outputs import BaseModelOutputWithPast
@@ -142,9 +142,10 @@ def qwen_attention_forward_original(
     rotary_pos_emb_list = rotary_pos_emb_list[:-1]
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
-    qtype_check = decoding_fast_path_qtype_check(self.q_proj)
-    decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                True,
+                                                bsz * q_len)
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k, cache_v = layer_past[0], layer_past[1]
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index e411f691..9cf7a640 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -57,7 +57,7 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_po
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 
 try:
     from transformers.cache_utils import Cache, DynamicCache
@@ -435,9 +435,10 @@ def qwen2_attention_forward_origin(
     device = hidden_states.device
 
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-    qtype_check = decoding_fast_path_qtype_check(self.q_proj)
-    decoding_fast_path = (qtype_check and use_fuse_rope
-                          and enough_kv_room and bsz * q_len == 1)
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len)
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k = past_key_value.key_cache[self.layer_idx]
@@ -604,9 +605,10 @@ def qwen2_sdpa_attention_forward(
     device = hidden_states.device
 
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
-    qtype_check = decoding_fast_path_qtype_check(self.q_proj)
-    decoding_fast_path = (qtype_check and use_fuse_rope
-                          and enough_kv_room and bsz * q_len == 1)
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len)
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k = past_key_value.key_cache[self.layer_idx]
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
index 869dc052..94b57297 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen_vl.py
@@ -33,7 +33,7 @@ from transformers.utils import logging
 from ipex_llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import rotate_half
 from ipex_llm.transformers.models.utils import use_esimd_sdp
-from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
+from ipex_llm.transformers.models.utils import use_decoding_fast_path
 import os
 
 
@@ -91,9 +91,10 @@ def qwen_attention_forward_vl(
     device = hidden_states.device
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
-    qtype_check = decoding_fast_path_qtype_check(self.q_proj)
-    decoding_fast_path = (qtype_check and use_fuse_rope and bsz * q_len == 1)
-    decoding_fast_path = decoding_fast_path and not self.q_proj.enable_xetla
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                True,
+                                                bsz * q_len)
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         cache_k, cache_v = layer_past[0], layer_past[1]
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 8f3b98a8..354b4831 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -369,6 +369,27 @@ def mlp_fusion_check(x, qtype, training):
     return True
 
 
+def use_decoding_fast_path(proj,
+                           use_fuse_rope,
+                           enough_kv_room,
+                           bs,
+                           qtype_check=decoding_fast_path_qtype_check):
+    device = get_xpu_device_type(proj.weight)
+    if not qtype_check(proj):
+        return False
+    if not use_fuse_rope:
+        return False
+    if not enough_kv_room:
+        return False
+    if bs != 1:
+        return False
+    if proj.enable_xetla:
+        return False
+    if device in ["uhd"]:
+        return False
+    return True
+
+
 def use_xmx(x: torch.Tensor, qtype: int):
     device = get_xpu_device_type(x)
     return (
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index f0e8f9cc..ad753710 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -35,20 +35,12 @@ from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, a
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_31, SILU
-from ipex_llm.transformers.low_bit_linear import SYM_INT4, FP8E5
-from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
 import os
 
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 
-def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs):
-    return decoding_fast_path_qtype_check(proj) and \
-        use_fuse_rope and enough_kv_room and bs == 1 \
-        and not proj.enable_xetla
-
-
 def should_use_fuse_rope(self, hidden_states, position_ids):
     use_fuse_rope = hidden_states.device.type == "xpu"
     use_fuse_rope = use_fuse_rope and not (self.training and hidden_states.requires_grad)
diff --git a/python/llm/src/ipex_llm/transformers/utils.py b/python/llm/src/ipex_llm/transformers/utils.py
index 8d2fbedc..49894a2b 100644
--- a/python/llm/src/ipex_llm/transformers/utils.py
+++ b/python/llm/src/ipex_llm/transformers/utils.py
@@ -180,6 +180,8 @@ def get_xpu_device_type(x):
         return "flex"
     elif name.startswith("Intel(R) Data Center GPU Max"):
         return "pvc"
+    elif name.startswith("Intel(R) UHD"):
+        return "uhd"
     else:
         return "others"
 
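Note: the patch above collapses the per-model copies of use_decoding_fast_path (gemma, mistral, yuan, and the inline checks in llama/qwen/qwen2) into the single helper in models/utils.py, which additionally rules out XeTLA projections and integrated "uhd" GPUs. Below is a minimal, self-contained sketch of that gating logic for reference only; the FakeProj stub, the placeholder qtype constant, and the hard-coded device attribute are illustrative assumptions standing in for the real low-bit Linear module and get_xpu_device_type(), not the shipped implementation.

from dataclasses import dataclass

SYM_INT4 = 2  # placeholder qtype id, for illustration only


@dataclass
class FakeProj:
    # stand-in for a low-bit Linear layer; the real helper inspects proj.weight
    qtype: int = SYM_INT4
    enable_xetla: bool = False
    device: str = "arc"  # pretend get_xpu_device_type(proj.weight) returned this


def decoding_fast_path_qtype_check(proj):
    # the real check also accepts other low-bit qtypes; simplified here
    return proj.qtype == SYM_INT4


def use_decoding_fast_path(proj, use_fuse_rope, enough_kv_room, bs,
                           qtype_check=decoding_fast_path_qtype_check):
    # mirrors the consolidated gate: every condition must hold, and
    # XeTLA projections plus integrated ("uhd") GPUs fall back to the normal path
    if not qtype_check(proj):
        return False
    if not use_fuse_rope or not enough_kv_room or bs != 1:
        return False
    if proj.enable_xetla:
        return False
    if proj.device in ["uhd"]:
        return False
    return True


if __name__ == "__main__":
    proj = FakeProj()
    # single-token decode (bsz * q_len == 1) with fused RoPE and enough KV room
    print(use_decoding_fast_path(proj, use_fuse_rope=True, enough_kv_room=True, bs=1))  # True
    # batched prefill takes the regular attention path
    print(use_decoding_fast_path(proj, use_fuse_rope=True, enough_kv_room=True, bs=8))  # False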