From 59df75032693e5ed14d1c7f4143462797a0d1330 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 16 May 2024 09:33:34 +0800
Subject: [PATCH] Use new sdp again (#11025)

---
 .../ipex_llm/transformers/models/baichuan2.py | 22 +++----
 .../ipex_llm/transformers/models/chatglm2.py  | 33 +++++-----
 .../src/ipex_llm/transformers/models/phi3.py  |  6 +-
 .../ipex_llm/transformers/models/qwen2_moe.py |  2 +-
 .../src/ipex_llm/transformers/models/utils.py | 66 +------------------
 5 files changed, 34 insertions(+), 95 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan2.py b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
index e0c76a5d..a1b9ddbf 100644
--- a/python/llm/src/ipex_llm/transformers/models/baichuan2.py
+++ b/python/llm/src/ipex_llm/transformers/models/baichuan2.py
@@ -28,10 +28,11 @@ from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_
     restore_fp8_kv_cache, use_quantize_kv_cache
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \
     append_kv_cache, is_enough_kv_cache_room_4_31
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, SILU
 from ipex_llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from ipex_llm.transformers.models.utils import mlp_fusion_check
+from ipex_llm.utils.common.log4Error import invalidInputError
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -166,9 +167,8 @@ def baichuan_attention_forward_7b_quantized(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    if attention_mask is not None:
-        if attention_mask.dtype == torch.bool:
-            attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")
 
     scaling_factor = 1 / math.sqrt(query_states.size(-1))
     if query_states.size(2) != 1 or device.type != 'xpu':
@@ -279,6 +279,9 @@ def baichuan_attention_forward_7b_origin(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
+    invalidInputError(attention_mask is None or attention_mask.dtype != torch.bool,
+                      "attention_mask's dtype cannot be bool")
+
     if xops is not None and self.training:
         attn_weights = None
         query_states = query_states.transpose(1, 2)
@@ -296,17 +299,12 @@ def baichuan_attention_forward_7b_origin(
                                                          is_causal=True)
         attn_weights = None
     elif not self.training and not hidden_states.requires_grad and \
-            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
-        import linear_fp16_esimd
-        attn_output = linear_fp16_esimd.sdp_forward(query_states,
-                                                    key_states,
-                                                    value_states)
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_q4_0
+        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
         attn_output = attn_output.view(query_states.shape)
         attn_weights = None
     else:
-        if attention_mask is not None:
-            if attention_mask.dtype == torch.bool:
-                attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
         if should_split_qkv_tensor(query_states, bsz, self.num_heads, q_len, kv_seq_len,
                                    output_attentions):
             attn_output, attn_weights = native_sdp_split_qkv_tensor(query_states,
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 1367e3a3..8c11c3aa 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -25,7 +25,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from ipex_llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \
     restore_fp8_kv_cache, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_sdp
 import os
 
@@ -558,25 +558,28 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
                                                              value_layer,
                                                              is_causal=True).to(key_layer.dtype)
     else:
-        if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
-                         query_layer.shape[-1], query_layer):
-            import linear_fp16_esimd
-            attn_output = linear_fp16_esimd.sdp_forward(query_layer,
-                                                        key_layer,
-                                                        value_layer)
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
+                                    device=query_layer.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        else:
+            attn_bias = None
+
+        if use_sdp(query_layer.shape[2], key_layer.shape[2],
+                   query_layer.shape[-1], query_layer):
+            import linear_q4_0
+            attn_output = linear_q4_0.sdp(query_layer, key_layer, value_layer, attn_bias)
             context_layer = attn_output.view(query_layer.shape)
         else:
             head_dim = query_layer.size(-1)
             attn = torch.matmul(query_layer.to(key_layer.dtype),
                                 key_layer.transpose(2, 3)) / math.sqrt(head_dim)
-            if attention_mask is not None:
-                attn_bias = torch.zeros(attention_mask.shape, dtype=query_layer.dtype,
-                                        device=query_layer.device)
-                attention_mask = ~attention_mask
-                if attention_mask.dtype == torch.bool:
-                    attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
-                else:
-                    attn_bias += attention_mask
+            if attn_bias is not None:
                 attn += attn_bias
             attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value_layer.dtype)
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index d1451666..6593e81b 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -41,8 +41,8 @@ from ipex_llm.transformers.models.utils import (
     apply_rotary_pos_emb_cache_freq_xpu
 )
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, use_quantize_kv_cache
-from ipex_llm.transformers.models.utils import use_sdp_fp8, restore_fp8_kv_cache
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
 
 from typing import Optional, Tuple, List
@@ -144,7 +144,7 @@ def attention_forward(
                                                 attention_mask)
         else:
             attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
-    elif use_sdp_causal(q_len, kv_seq_len, query_states, self.training):
+    elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import linear_q4_0
         if isinstance(past_key_value, DynamicFp8Cache):
             attn_output = linear_q4_0.sdp_fp8_causal(query_states, key_states, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
index c7b242c7..c136c8f7 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -52,7 +52,7 @@ from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb
 from ipex_llm.utils.common import invalidInputError
 from ipex_llm.transformers.models.utils import decoding_fast_path_qtype_check
-from ipex_llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from ipex_llm.transformers.models.utils import use_flash_attention
 from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeModel, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.kv import DynamicFp8Cache
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index b899370d..22320a4b 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -318,69 +318,6 @@ def use_flash_attention(query, key, attention_mask=None):
     return True
 
 
-def use_esimd_sdp(q_len, k_len, head_dim, query_states, attention_mask=None):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token and q_len == 1 now
-        return False
-    elif k_len < 8:
-        # esimd_sdp will cause wrong output when k_len < 8
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-
-    device_name = torch.xpu.get_device_name(query_states.device.index)
-    if device_name.startswith("Intel(R) Arc(TM) A") or \
-            device_name.startswith("Intel(R) Data Center GPU Flex") or \
-            device_name.startswith("Intel(R) Data Center GPU Max"):
-        import linear_fp16_esimd
-        if not hasattr(linear_fp16_esimd, "sdp_forward"):
-            return False
-    else:
-        return False
-
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Data Center GPU Max"):
-        # esimd_sdp not support PVC GPU when batch size > 1 for now
-        return False
-    if query_states.shape[0] > 1 and device_name.startswith("Intel(R) Arc(TM) A") \
-            and is_deepspeed_available:
-        # esimd_sdp not support ARC GPU when batch size > 1 using DeepSpeed AutoTP for now
-        return False
-    if query_states.shape[0] > 1 and attention_mask is not None:
-        # for batched input, can't accept attention_mask
-        # TODO: this check needs some time
-        if not torch.all(attention_mask.eq(0)):
-            return False
-
-    return True
-
-
-def use_new_esimd_sdp_fp16(q_len, k_len, head_dim, query_states):
-    if query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    elif head_dim not in [64, 96, 128]:
-        # esimd_sdp only support head_dim = 128 and 64 now
-        return False
-    elif q_len == k_len:
-        # new sdp_fp16 only support rest token now
-        return False
-    elif q_len > 32:
-        # Use new sdp_fp16 only when q_len <= 32
-        return False
-
-    return True
-
-
 def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
@@ -400,9 +337,10 @@ def use_sdp_fp8(q_len, kv_len, query_states):
     )
 
 
-def use_sdp_causal(q_len, kv_len, query_states, training):
+def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
+        and head_dim in [64, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
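
Note: below is a minimal sketch of the gating pattern this patch standardizes on, mirroring the baichuan2.py hunk above. use_sdp and linear_q4_0.sdp are the calls introduced by the patch; the native_sdp_fallback callable and the assumption of fp16 (batch, num_heads, seq, head_dim) tensors on an XPU device are illustrative and not part of the patch.

# Sketch only, not part of the applied diff.
from ipex_llm.transformers.models.utils import use_sdp


def sdp_or_fallback(query_states, key_states, value_states, attention_mask,
                    q_len, head_dim, native_sdp_fallback):
    kv_len = key_states.shape[2]
    if use_sdp(q_len, kv_len, head_dim, query_states):
        # Fused decode kernel shipped with ipex-llm; this patch routes to it
        # instead of the removed linear_fp16_esimd.sdp_forward path.
        import linear_q4_0
        attn_output = linear_q4_0.sdp(query_states, key_states, value_states, attention_mask)
        return attn_output.view(query_states.shape)
    # Otherwise keep the model's reference attention implementation.
    return native_sdp_fallback(query_states, key_states, value_states, attention_mask)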