LLM: fix wrong batch output caused by flash attention (#9780)
* fix
* meet code review
* move batch size check to the beginning
* move qlen check inside function
* meet code review
This commit is contained in:
parent 66e286a73d
commit 11d883301b

1 changed file with 15 additions and 6 deletions
@@ -137,10 +137,10 @@ def llama_attention_forward_4_31(
     # for flash attention
     original_dtype = hidden_states.dtype
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = check_flash_attention_available(hidden_states)
+        fsdp_flag = use_flash_attention(hidden_states)
     else:
         fsdp_flag = False
-    if fsdp_flag and q_len > 1:
+    if fsdp_flag:
         attention_dtype = torch.float16  # use fp16 for flash attention
     else:
         attention_dtype = original_dtype
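The call sites get simpler because the batch-size and q_len checks now live inside the helper (added in the last hunk of this diff). A minimal, hypothetical stand-in, built only from the checks visible in the diff and not the library's actual code, to illustrate why "and q_len > 1" is no longer needed here:

import torch

# Hypothetical stand-in for use_flash_attention(); simplified from the checks
# visible in this diff, not the library's real implementation.
def use_flash_attention_sketch(query: torch.Tensor) -> bool:
    bsz, q_len, _ = query.size()
    if bsz > 1:        # padded batches give wrong output without attn_mask (ipex 2.1)
        return False
    if q_len == 1:     # rest-token (decode) steps: no observed benefit
        return False
    if query.device.type != "xpu":  # ipex flash attention is xpu-only
        return False
    return True

# Only a bsz == 1 prefill pass qualifies (and only on an xpu device), so callers
# can gate on "if fsdp_flag:" alone instead of "if fsdp_flag and q_len > 1:".
print(use_flash_attention_sketch(torch.randn(1, 128, 4096)))  # False on CPU, True on xpu
print(use_flash_attention_sketch(torch.randn(1, 1, 4096)))    # False: q_len == 1
print(use_flash_attention_sketch(torch.randn(4, 128, 4096)))  # False: bsz > 1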
@@ -261,8 +261,7 @@ def llama_attention_forward_4_31(
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)

-    if fsdp_flag and q_len > 1:
-        # now only use flash attention for first token
+    if fsdp_flag:
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
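For context, the surrounding pattern casts query/key/value to the flash-attention dtype and casts the result back to the original dtype before returning (the function later returns attn_output.to(original_dtype)). A minimal sketch of that cast-in / cast-back handling using stock PyTorch SDPA, assuming nothing ipex-specific:

import torch
import torch.nn.functional as F

# Sketch of the dtype handling around the SDPA call. Uses plain PyTorch on
# whatever device is present; the patched code targets ipex flash attention on xpu.
original_dtype = torch.float32
fsdp_flag = torch.cuda.is_available()           # stand-in for the use_flash_attention gate
attention_dtype = torch.float16 if fsdp_flag else original_dtype
device = "cuda" if fsdp_flag else "cpu"

q = torch.randn(1, 32, 128, 64, device=device)  # (bsz, num_heads, q_len, head_dim)
k = torch.randn(1, 32, 128, 64, device=device)
v = torch.randn(1, 32, 128, 64, device=device)

attn_output = F.scaled_dot_product_attention(q.to(dtype=attention_dtype),
                                             k.to(dtype=attention_dtype),
                                             v.to(dtype=attention_dtype))
attn_output = attn_output.to(original_dtype)    # cast back, matching the function's return
print(attn_output.dtype)                        # torch.float32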
@@ -325,7 +324,7 @@ def llama_attention_selective_batching_forward_4_31(
     original_dtype = hidden_states.dtype
     # TODO: consider this later - flash attention
     # if not self.training and not hidden_states.requires_grad:
-    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    #     fsdp_flag = use_flash_attention(hidden_states)
     # else:
     #     fsdp_flag = False
     # if fsdp_flag and q_len > 1:
@@ -506,8 +505,18 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def check_flash_attention_available(query):
+def use_flash_attention(query):
+    bsz, q_len, _ = query.size()
     # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 for now,
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it causes wrong output for padded batch input
+        return False
+    if q_len == 1:
+        # only use flash attention for the first token for now,
+        # as it seems to have no performance benefit for rest tokens
+        return False
     if query.device.type != "xpu":
         # ipex flash attention is only supported on xpu
         return False
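Why bsz > 1 is rejected: per the comments above, the ipex 2.1 flash attention path does not take an attn_mask, so a padded batch attends to its padding tokens and even the real tokens come out wrong. A minimal, self-contained illustration with stock PyTorch SDPA on CPU (an analogy for dropping the mask, not the ipex kernel itself):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, num_heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Pretend the last two positions of sequence 1 are padding and must not be attended to.
pad_len = 2
attn_mask = torch.zeros(bsz, 1, seq_len, seq_len)
attn_mask[1, :, :, -pad_len:] = float("-inf")

masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
unmasked = F.scaled_dot_product_attention(q, k, v)  # what a mask-less kernel computes

real = slice(0, seq_len - pad_len)
# Real (non-padding) tokens of the padded sequence diverge once the mask is dropped:
print((masked[1, :, real] - unmasked[1, :, real]).abs().max())  # clearly non-zero
# The unpadded sequence (index 0) is essentially unaffected, which is why bsz == 1 stays enabled:
print((masked[0] - unmasked[0]).abs().max())                    # ~0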