diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index e76b3c71..7dabab3a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -349,7 +349,7 @@ def llama_attention_forward_4_31_quantized(
                                                      cos, sin, position_ids, "llama")
 
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
@@ -629,7 +629,7 @@ def llama_attention_forward_4_31_original(
     past_key_value = (key_states, value_states) if use_cache else None
 
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
@@ -1068,7 +1068,7 @@ def llama_attention_forward_4_36(
             past_key_value.value_cache[self.layer_idx] = value_states
 
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states)
+        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
     else:
         fsdp_flag = False
     if fsdp_flag:
diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py
index 916e855b..de320e9b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/utils.py
+++ b/python/llm/src/bigdl/llm/transformers/models/utils.py
@@ -236,7 +236,7 @@ def is_enough_kv_cache_room_4_31(past_key_value, seq_len=1):
         (past_key_value[0].size(2) + seq_len) * past_key_value[0].size(3)
 
 
-def use_flash_attention(query, key):
+def use_flash_attention(query, key, attention_mask=None):
     # here we support query's shape is always [batch_size, head_num, q_len, head_dim],
     # key's shape is always [batch_size, head_num, k_len, head_dim]
     invalidInputError(query.dim() == 4,
@@ -248,11 +248,6 @@ def use_flash_attention(query, key):
     bsz, _, q_len, _ = query.size()
     k_len = key.size()[2]
     # check whether ipex flash attention can be used
-    if bsz > 1:
-        # only use flash attention for batch_size = 1 now
-        # as flash attention doesn't support attn_mask in ipex 2.1,
-        # so it will cause output error for padded batch input
-        return False
     if q_len != k_len:
         # now only use flash attention for first token
         # as it seems have no performance benifit for rest token now
@@ -271,6 +266,22 @@ def use_flash_attention(query, key):
     if query.dtype not in [torch.float32, torch.float16]:
         # only use flash attention for fp32/fp16 input
         return False
+    if bsz > 1:
+        # flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause output errors for padded batch input
+        if attention_mask is None:
+            return True
+        else:
+            # TODO: the logic below may change for different models
+            # attention mask shape: [bsz, 1, q_len, k_len]
+            if attention_mask[0].squeeze()[0, 0].item() != 0:
+                # the first sequence in the batch contains padding;
+                # otherwise the mask should be an upper triangular matrix
+                # whose diagonal is also 0
+                return False
+            elif not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
+                # check whether the mask of every sequence in the batch is the same
+                return False
     return True
 
 
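For reference, the sketch below reproduces the new `bsz > 1` eligibility check in isolation so it can be tried outside of BigDL. The helper name `batch_mask_allows_flash_attention` is made up for illustration, and the example assumes HuggingFace-style additive masks (0 where attention is allowed, a large negative value where it is masked) with shape `[bsz, 1, q_len, k_len]`.

```python
import torch


def batch_mask_allows_flash_attention(attention_mask, bsz):
    # mirrors the new bsz > 1 branch of use_flash_attention:
    # flash attention is only used when no sequence in the batch is padded,
    # i.e. every per-sample mask is the same additive causal mask
    # (0 on and below the diagonal, -inf above it)
    if attention_mask is None:
        return True
    # expected mask shape: [bsz, 1, q_len, k_len]
    if attention_mask[0].squeeze()[0, 0].item() != 0:
        # a non-zero top-left entry means the first sequence starts with padding
        return False
    if not attention_mask.equal(attention_mask[0].repeat(bsz, 1, 1, 1)):
        # per-sample masks differ, so at least one sequence is padded
        return False
    return True


q_len = 4
# plain causal mask, identical for both sequences -> flash attention is allowed
causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
uniform = causal.expand(2, 1, q_len, q_len)
print(batch_mask_allows_flash_attention(uniform, bsz=2))  # True

# mask out the first key position of the second sequence (left padding) -> rejected
padded = uniform.clone()
padded[1, :, :, 0] = float("-inf")
print(batch_mask_allows_flash_attention(padded, bsz=2))  # False
```

This matches the intent of the diff: since flash attention in ipex 2.1 ignores attn_mask, batches whose masks show real padding must fall back to the regular attention path.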