From 16433dd95922c51a9b05e2f1c18ae2af904ec0f1 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Fri, 5 Jan 2024 13:49:37 +0800 Subject: [PATCH] LLM: fix first token judgement of flash attention (#9841) * fix flash attention * meet code review * fix --- .../bigdl/llm/transformers/models/chatglm2.py | 2 +- .../src/bigdl/llm/transformers/models/llama.py | 17 +++++++++-------- .../src/bigdl/llm/transformers/models/utils.py | 18 ++++++++++++------ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py index 08ee571f..f29a0f49 100644 --- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py +++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py @@ -367,7 +367,7 @@ def core_attn_forward_8eb45c(self, query_layer, key_layer, value_layer, attentio if pytorch_major_version >= 2 and (query_layer.device.type == 'xpu' or query_layer.size(0) > 1): query_layer = query_layer.permute(1, 2, 0, 3) L, S = query_layer.shape[2], key_layer.shape[2] - if attention_mask is None and (use_flash_attention(query_layer) or + if attention_mask is None and (use_flash_attention(query_layer, key_layer) or L == S and query_layer.device.type == "cpu"): context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index f01d0cd3..a4f6dfc0 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -139,14 +139,6 @@ def llama_attention_forward_4_31( device = hidden_states.device # for flash attention original_dtype = hidden_states.dtype - if not self.training and not hidden_states.requires_grad: - fsdp_flag = use_flash_attention(hidden_states) - else: - fsdp_flag = False - if fsdp_flag: - attention_dtype = torch.float16 # use fp16 for flash attention - else: - attention_dtype = original_dtype use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value, seq_len=q_len) @@ -259,6 +251,15 @@ def llama_attention_forward_4_31( past_key_value = (key_states, value_states) if use_cache else None + if not self.training and not hidden_states.requires_grad: + fsdp_flag = use_flash_attention(query_states, key_states) + else: + fsdp_flag = False + if fsdp_flag: + attention_dtype = torch.float16 # use fp16 for flash attention + else: + attention_dtype = original_dtype + # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups).to(device, dtype=attention_dtype) diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index f009d223..5736bcd1 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -123,18 +123,24 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1): (past_key_value[0].size(2) + seq_len - 1) * past_key_value[0].size(3) -def use_flash_attention(query): - if query.dim() == 3: - bsz, q_len, _ = query.size() - elif query.dim() == 4: - bsz, _, q_len, _ = query.size() +def use_flash_attention(query, key): + # here we support query's shape is always [batch_size, head_num, q_len, head_dim], + # key's shape is always [batch_size, head_num, k_len, head_dim] + invalidInputError(query.dim() == 4, + "Here query input of use_flash_attention should be [batch_size, " + "head_num, q_len, head_dim]") + invalidInputError(key.dim() == 4, + "Here key input of use_flash_attention should be [batch_size, " + "head_num, k_len, head_dim]") + bsz, _, q_len, _ = query.size() + k_len = key.size()[2] # check whether ipex flash attention can be used if bsz > 1: # only use flash attention for batch_size = 1 now # as flash attention doesn't support attn_mask in ipex 2.1, # so it will cause output error for padded batch input return False - if q_len == 1: + if q_len != k_len: # now only use flash attention for first token # as it seems have no performance benifit for rest token now return False