LLM: fix wrong batch output caused by flash attention (#9780)
* fix
* meet code review
* move batch size check to the beginning
* move qlen check inside function
* meet code review
This commit is contained in:
parent 66e286a73d
commit 11d883301b
1 changed file with 15 additions and 6 deletions
@@ -137,10 +137,10 @@ def llama_attention_forward_4_31(
     # for flash attention
     original_dtype = hidden_states.dtype
     if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = check_flash_attention_available(hidden_states)
+        fsdp_flag = use_flash_attention(hidden_states)
     else:
         fsdp_flag = False
-    if fsdp_flag and q_len > 1:
+    if fsdp_flag:
         attention_dtype = torch.float16  # use fp16 for flash attention
     else:
         attention_dtype = original_dtype
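With the q_len check moved inside the helper, the call site only asks the renamed gate and then picks the attention dtype. A condensed, self-contained sketch of the resulting flag/dtype selection (the function name, its arguments, and the pluggable `gate` parameter below are illustrative, not part of the patch):

    import torch

    def pick_attention_dtype(hidden_states, training, gate):
        # Illustrative sketch only: mirrors the post-patch call-site logic,
        # with the use_flash_attention check passed in as `gate`.
        if not training and not hidden_states.requires_grad:
            fsdp_flag = gate(hidden_states)
        else:
            fsdp_flag = False
        # fp16 for the ipex flash-attention path, original dtype otherwise.
        return fsdp_flag, (torch.float16 if fsdp_flag else hidden_states.dtype)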
@@ -261,8 +261,7 @@ def llama_attention_forward_4_31(
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
                                                                          dtype=attention_dtype)

-    if fsdp_flag and q_len > 1:
-        # now only use flash attention for first token
+    if fsdp_flag:
         attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                      key_states,
                                                      value_states,
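As the comments added to the new helper explain, ipex 2.1's flash attention does not support attn_mask, so a padded batch pushed through scaled_dot_product_attention without a mask attends to the padding and produces wrong output for the padded sequences. A minimal, ipex-independent illustration of that effect in plain PyTorch (tensor shapes and the padding length are made up for the example):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    bsz, heads, seq_len, head_dim = 2, 4, 8, 16
    q = torch.randn(bsz, heads, seq_len, head_dim)
    k = torch.randn(bsz, heads, seq_len, head_dim)
    v = torch.randn(bsz, heads, seq_len, head_dim)

    # Pretend the second sequence is right-padded by 3 tokens: the additive mask
    # blocks attention to those key positions; the first sequence has no padding.
    mask = torch.zeros(bsz, 1, seq_len, seq_len)
    mask[1, :, :, -3:] = float("-inf")

    masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    unmasked = F.scaled_dot_product_attention(q, k, v)

    print(torch.allclose(masked[0], unmasked[0]))  # True:  unpadded sequence is unaffected
    print(torch.allclose(masked[1], unmasked[1]))  # False: padded sequence gets different output

This is why the helper below refuses the flash path for batch_size > 1 instead of relying on the mask.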
@@ -325,7 +324,7 @@ def llama_attention_selective_batching_forward_4_31(
     original_dtype = hidden_states.dtype
     # TODO: consider this later - flash attention
     # if not self.training and not hidden_states.requires_grad:
-    #     fsdp_flag = check_flash_attention_available(hidden_states)
+    #     fsdp_flag = use_flash_attention(hidden_states)
     # else:
     #     fsdp_flag = False
     # if fsdp_flag and q_len > 1:
@@ -506,8 +505,18 @@ def llama_attention_selective_batching_forward_4_31(
     return attn_output.to(original_dtype), attn_weights, updated_past_key_values


-def check_flash_attention_available(query):
+def use_flash_attention(query):
+    bsz, q_len, _ = query.size()
     # check whether ipex flash attention can be used
+    if bsz > 1:
+        # only use flash attention for batch_size = 1 now,
+        # as flash attention doesn't support attn_mask in ipex 2.1,
+        # so it would produce wrong output for padded batch input
+        return False
+    if q_len == 1:
+        # now only use flash attention for the first token,
+        # as it seems to have no performance benefit for the rest tokens now
+        return False
     if query.device.type != "xpu":
         # ipex flash attention only support for xpu
         return False
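Since the hunk interleaves the renamed definition with the new early-exit checks, here is the helper as it reads after the patch, consolidated for readability. Note that the hunk ends at the xpu check, so anything after it (including the final return True assumed below) is not shown in this diff:

    import torch  # only needed for the usage example below

    def use_flash_attention(query):
        bsz, q_len, _ = query.size()
        # check whether ipex flash attention can be used
        if bsz > 1:
            # only use flash attention for batch_size = 1 now, as flash attention
            # doesn't support attn_mask in ipex 2.1 and would produce wrong
            # output for padded batch input
            return False
        if q_len == 1:
            # only use flash attention for the first token, as it seems to have
            # no performance benefit for the rest tokens now
            return False
        if query.device.type != "xpu":
            # ipex flash attention is only supported on xpu
            return False
        # The diff stops here; the remainder of the function (assumed to end in
        # `return True` once all checks pass) lies outside this hunk.
        return True

    # Hypothetical usage: a prefill batch of size 1 on CPU is still rejected,
    # because the device check requires xpu.
    hidden_states = torch.randn(1, 128, 4096)  # (bsz, q_len, hidden_size)
    print(use_flash_attention(hidden_states))  # False on CPU; only xpu prefill with bsz == 1 passes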