diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 898a673a..b4f5604c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.utils import use_flash_attention
 from bigdl.llm.transformers.models.llama import get_ipex_version
 
 
@@ -365,11 +366,10 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio
     pytorch_major_version = int(torch.__version__.split('.')[0])
     if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1):
         query_layer = query_layer.permute(1, 2, 0, 3)
-        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+        if attention_mask is None and use_flash_attention(query_layer):
             context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer,
                                                                              key_layer,
                                                                              value_layer,
-                                                                             attention_mask,
                                                                              is_causal=True)
         elif attention_mask is None:
             scaling_factor = 1 / math.sqrt(query_layer.size(-1))
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 12a5d21d..d4e9e223 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -42,6 +42,7 @@ from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -508,61 +509,6 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 
 
-def use_flash_attention(query):
-    bsz, q_len, _ = query.size()
-    # check whether ipex flash attention can be used
-    if bsz > 1:
-        # only use flash attention for batch_size = 1 now
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        return False
-    if q_len == 1:
-        # now only use flash attention for first token
-        # as it seems have no performance benifit for rest token now
-        return False
-    if query.device.type != "xpu":
-        # ipex flash attention only support for xpu
-        return False
-    ipex_version = get_ipex_version()
-    if ipex_version <= "2.0.110+xpu":
-        # ipex flash attention is supported from ipex 2.1
-        return False
-    if not torch.xpu.has_xetla():
-        # ipex flash attention is only supported for xetla
-        # may update this later
-        return False
-    if query.dtype not in [torch.float32, torch.float16]:
-        # only use flash attention for fp32/fp16 input
-        return False
-    return True
-
-
-def use_esimd_sdp(q_len, head_dim, query_states):
-    if head_dim != 128:
-        # esimd_sdp only support head_dim = 128 now
-        return False
-    elif q_len != 1:
-        # esimd_sdp only support rest token now
-        return False
-    elif query_states.device.type != "xpu":
-        # esimd_sdp only support GPU now
-        return False
-    elif query_states.dtype != torch.float16:
-        # esimd_sdp only has optimization for FP16 now
-        return False
-    else:
-        device_name = torch.xpu.get_device_name(query_states.device.index)
-        if device_name.startswith("Intel(R) Arc(TM) A") or \
-                device_name.startswith("Intel(R) Data Center GPU Flex"):
-            import linear_fp16_esimd
-            if hasattr(linear_fp16_esimd, "sdp_forward"):
-                return True
-            else:
-                return False
-        else:
-            return False
-
-
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads):
     attn_weights = torch.matmul(query,
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 15db1102..502fabfa 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -16,6 +16,7 @@
 
 import torch
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.utils import get_ipex_version
 
 
 def init_kv_cache(batch_size, num_heads, head_dim, current_length, max_length, dtype, device):
@@ -119,3 +120,61 @@ def is_enough_kv_cache_room_4_31(past_key_value):
     # to determinate if is enough kv cache room in transformers between 4.31 and 4.35
     return past_key_value is not None and \
         past_key_value[0].stride()[1] > past_key_value[0].size(2) * past_key_value[0].size(3)
+
+
+def use_flash_attention(query):
+    if query.dim() == 3:
+        bsz, q_len, _ = query.size()
+    elif query.dim() == 4:
+        bsz, _, q_len, _ = query.size()
+    # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 now
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output error for padded batch input
+        return False
+    if q_len == 1:
+        # now only use flash attention for first token
+        # as it seems have no performance benifit for rest token now
+        return False
+    if query.device.type != "xpu":
+        # ipex flash attention only support for xpu
+        return False
+    ipex_version = get_ipex_version()
+    if ipex_version <= "2.0.110+xpu":
+        # ipex flash attention is supported from ipex 2.1
+        return False
+    if not torch.xpu.has_xetla():
+        # ipex flash attention is only supported for xetla
+        # may update this later
+        return False
+    if query.dtype not in [torch.float32, torch.float16]:
+        # only use flash attention for fp32/fp16 input
+        return False
+    return True
+
+
+def use_esimd_sdp(q_len, head_dim, query_states):
+    if head_dim != 128:
+        # esimd_sdp only support head_dim = 128 now
+        return False
+    elif q_len != 1:
+        # esimd_sdp only support rest token now
+        return False
+    elif query_states.device.type != "xpu":
+        # esimd_sdp only support GPU now
+        return False
+    elif query_states.dtype != torch.float16:
+        # esimd_sdp only has optimization for FP16 now
+        return False
+    else:
+        device_name = torch.xpu.get_device_name(query_states.device.index)
+        if device_name.startswith("Intel(R) Arc(TM) A") or \
+                device_name.startswith("Intel(R) Data Center GPU Flex"):
+            import linear_fp16_esimd
+            if hasattr(linear_fp16_esimd, "sdp_forward"):
+                return True
+            else:
+                return False
+        else:
+            return False
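
For reference, below is a minimal sketch (not part of the patch) of how a caller might dispatch on the relocated helpers. The 4D [batch, heads, seq_len, head_dim] layout, the function name `sdp_dispatch`, and the plain-PyTorch fallback are illustrative assumptions; the ESIMD kernel call itself is omitted because its signature does not appear in this diff.

```python
# Illustrative sketch only -- not the implementation in llama.py/chatglm2.py.
import math
import torch
import torch.nn.functional as F
from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp


def sdp_dispatch(query, key, value, attention_mask=None):
    # query/key/value assumed to be [batch, num_heads, seq_len, head_dim]
    bsz, num_heads, q_len, head_dim = query.size()

    if attention_mask is None and use_flash_attention(query):
        # prefill (first token) path: batch_size == 1, fp16/fp32, XPU with xetla,
        # ipex >= 2.1 -- hand the call to the fused scaled_dot_product_attention
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)

    if use_esimd_sdp(q_len, head_dim, query):
        # decode (rest token) path on Arc/Flex with fp16 and head_dim == 128;
        # the patched llama.py routes this to linear_fp16_esimd.sdp_forward,
        # whose call is not reproduced here
        pass

    # generic fallback: plain scaled-dot-product attention in PyTorch
    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)
```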