diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 56ccc211..538cf1b0 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -137,10 +137,10 @@ def llama_attention_forward_4_31(
     # for flash attention
     original_dtype = hidden_states.dtype
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = check_flash_attention_available(hidden_states)
+        fsdp_flag = use_flash_attention(hidden_states)
     else:
         fsdp_flag = False
-    if fsdp_flag and q_len > 1:
+    if fsdp_flag:
         attention_dtype = torch.float16  # use fp16 for flash attention
     else:
         attention_dtype = original_dtype
@@ -261,8 +261,7 @@ def llama_attention_forward_4_31(
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)
 
-    if fsdp_flag and q_len > 1:
-        # now only use flash attention for first token
+    if fsdp_flag:
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
@@ -325,7 +324,7 @@ def llama_attention_selective_batching_forward_4_31(
     original_dtype = hidden_states.dtype
     # TODO: consider this later - flash attention
     # if not self.training and not hidden_states.requires_grad:
-    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    #     fsdp_flag = use_flash_attention(hidden_states)
    # else:
    #     fsdp_flag = False
    # if fsdp_flag and q_len > 1:
@@ -506,8 +505,18 @@
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values
 
 
-def check_flash_attention_available(query):
+def use_flash_attention(query):
+    bsz, q_len, _ = query.size()
     # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 for now,
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it would produce incorrect output for padded batch input
+        return False
+    if q_len == 1:
+        # only use flash attention for the first token for now,
+        # as it seems to have no performance benefit for the rest tokens
+        return False
     if query.device.type != "xpu":
         # ipex flash attention only support for xpu
         return False
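
The gating logic introduced above can be read on its own: flash attention is taken only for inference on XPU, with batch size 1 and a multi-token (prompt/first-token) query. Below is a minimal sketch restating those checks as a standalone helper; the name `should_use_flash_attention`, the explicit `training` parameter, and the dtype selection at the end are illustrative assumptions rather than BigDL's actual API, and the real `use_flash_attention` contains further checks that the diff truncates.

```python
import torch


def should_use_flash_attention(query: torch.Tensor, training: bool) -> bool:
    # Sketch of the checks visible in the patch; the real BigDL helper
    # (use_flash_attention) performs additional checks not shown here.
    if training or query.requires_grad:
        # flash attention is only considered for inference
        return False
    bsz, q_len, _ = query.size()
    if bsz > 1:
        # ipex 2.1 flash attention has no attn_mask support, so padded
        # batches would give wrong results
        return False
    if q_len == 1:
        # single-token decode shows no benefit; only use it for the prompt
        return False
    if query.device.type != "xpu":
        # the ipex kernel is only available on xpu
        return False
    return True


# Example: pick the attention dtype the same way the patched forward does.
hidden_states = torch.randn(1, 128, 4096)  # (batch, seq_len, hidden_size)
fsdp_flag = should_use_flash_attention(hidden_states, training=False)
attention_dtype = torch.float16 if fsdp_flag else hidden_states.dtype
```

Note that the patch also moves the `q_len > 1` condition out of the call sites and into the helper itself, so `fsdp_flag` alone now decides both the fp16 cast and the `scaled_dot_product_attention` path.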