diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py
index 28cb5efe..ef3da70d 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -42,7 +42,8 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
 from ipex_llm.transformers.models.utils import rotate_half, SILU
 from ipex_llm.transformers.models.utils import mlp_fusion_check
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_new_esimd_sdp_fp16, \
+    use_sdp_fp8
 from ipex_llm.transformers.models.utils import use_decoding_fast_path
 from ipex_llm.utils.common import invalidInputError, invalidOperationError
 from ipex_llm.ggml.quantize import ggml_tensor_qtype
@@ -290,11 +291,9 @@ def qwen_attention_forward_original(
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query,
-                                                    key,
-                                                    value)
+            use_new_esimd_sdp_fp16(q_len, key.shape[2], self.head_dim, query):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp_fp16(query, key, value, attention_mask)
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weight = None
@@ -485,7 +484,7 @@ def qwen_attention_forward_quantized(
 
 
 def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None):
-    if query.size(2) != 1 or query.device.type != 'xpu':
+    if not use_sdp_fp8(query.size(2), key.size(2), query):
         # We have no CPU fp8 matmul implementation for now, so just upscale to fp32
         key, value = restore_fp8_kv_cache(key, value, query.dtype)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index b5108ae4..1494a657 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -380,13 +380,6 @@ def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
         # Use new sdp_fp16 only when q_len <= 32
         return False
 
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # It seems there is an issue in DeepSpeed AutoTP when multi-card inference,
-        # Disable new sdp_fp16 for now
-        return False
-
     return True
 
 
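Note on the new gating: in core_attn the inline device/length check is replaced by a call to use_sdp_fp8(query.size(2), key.size(2), query). As a minimal, hypothetical sketch (not the actual helper added to utils.py, which may apply different thresholds), a gate that simply mirrors the condition the old inline check expressed could look like this:

import torch


def use_sdp_fp8_sketch(q_len: int, k_len: int, query_states: torch.Tensor) -> bool:
    # Hypothetical sketch only; k_len is accepted to match the call site in
    # core_attn but is unused here.
    if query_states.device.type != "xpu":
        # No fused fp8 SDP kernel off-GPU: the caller restores the fp8 kv cache
        # to fp32 and falls back to torch.matmul.
        return False
    if q_len != 1:
        # The old inline check only took the fused path when decoding a single
        # token; prefill keeps the fp32 fallback in this sketch.
        return False
    return True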