LLM: fix wrong batch output caused by flash attention (#9780)

* fix

* meet code review

* move batch size check to the beginning

* move qlen check inside function

* meet code review
Ruonan Wang 2023-12-26 09:41:27 +08:00 committed by GitHub
parent 66e286a73d
commit 11d883301b

@@ -137,10 +137,10 @@ def llama_attention_forward_4_31(
     # for flash attention
     original_dtype = hidden_states.dtype
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = check_flash_attention_available(hidden_states)
+        fsdp_flag = use_flash_attention(hidden_states)
     else:
         fsdp_flag = False
-    if fsdp_flag and q_len > 1:
+    if fsdp_flag:
         attention_dtype = torch.float16  # use fp16 for flash attention
     else:
         attention_dtype = original_dtype
@@ -261,8 +261,7 @@ def llama_attention_forward_4_31(
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)
-    if fsdp_flag and q_len > 1:
-        # now only use flash attention for first token
+    if fsdp_flag:
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
@@ -325,7 +324,7 @@ def llama_attention_selective_batching_forward_4_31(
     original_dtype = hidden_states.dtype
     # TODO: consider this later - flash attention
     # if not self.training and not hidden_states.requires_grad:
-    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    #     fsdp_flag = use_flash_attention(hidden_states)
     # else:
     #     fsdp_flag = False
     # if fsdp_flag and q_len > 1:
@@ -506,8 +505,18 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def check_flash_attention_available(query):
+def use_flash_attention(query):
+    bsz, q_len, _ = query.size()
     # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 for now,
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it will cause wrong output for padded batch input
+        return False
+    if q_len == 1:
+        # now only use flash attention for the first token,
+        # as it seems to have no performance benefit for the rest of the tokens
+        return False
     if query.device.type != "xpu":
         # ipex flash attention only support for xpu
         return False
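
For context, the sketch below restates the new use_flash_attention gate from this diff together with a hypothetical caller (attention_forward_sketch is an illustrative name, not part of the BigDL-LLM code) to show how the batch-size, sequence-length, and device checks choose between the fp16 flash-attention path and the masked fallback. It is a minimal standalone approximation, not the actual llama_attention_forward_4_31 implementation, and it applies the gate directly to the query tensor rather than to hidden_states.

```python
import torch
import torch.nn.functional as F


def use_flash_attention(query: torch.Tensor) -> bool:
    # same gating logic as in the diff above; query is assumed (batch, seq_len, dim)
    bsz, q_len, _ = query.size()
    if bsz > 1:
        # flash attention doesn't support attn_mask in ipex 2.1,
        # so padded batch input would produce wrong output
        return False
    if q_len == 1:
        # only worthwhile for the first (prefill) token; no benefit for decode tokens
        return False
    if query.device.type != "xpu":
        # ipex flash attention is only available on xpu
        return False
    return True


def attention_forward_sketch(query, key, value, attention_mask=None):
    # hypothetical caller showing how the gate selects the attention path
    original_dtype = query.dtype
    if use_flash_attention(query):
        # fp16 flash-attention path: only reached for bsz == 1 prefill on xpu,
        # where no attention mask is needed
        attn_output = F.scaled_dot_product_attention(query.to(torch.float16),
                                                     key.to(torch.float16),
                                                     value.to(torch.float16),
                                                     is_causal=True)
    else:
        # masked SDPA fallback in the original dtype (handles padded batches)
        attn_output = F.scaled_dot_product_attention(query, key, value,
                                                     attn_mask=attention_mask)
    return attn_output.to(original_dtype)
```

The essential part of the fix is the bsz > 1 early return: because ipex 2.1's flash attention ignores attn_mask, padded batch inputs must take the masked fallback path instead of flash attention.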