LLM: add flash attention support for llama (#9518)
* add initial flash attention for llama
* accelerate fp32 first token by changing to fp16 in advance
* support fp32
parent bf579507c2
commit b63aae8a8e
1 changed file with 68 additions and 24 deletions
@@ -106,6 +106,16 @@ def llama_attention_forward_4_31(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = check_flash_attention_available(hidden_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag and q_len > 1:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype

     if self.config.pretraining_tp > 1:
         key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
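The hunk above gates the new path on two conditions, inference (no training, no grad) and a device-side availability check, and it only switches the compute dtype to fp16 when flash attention will actually be used for a multi-token (first-token/prefill) pass. A minimal, self-contained sketch of that selection logic; the helper names `_flash_available` and `pick_attention_dtype` are illustrative stand-ins, not part of the patch:

import torch

def _flash_available(x: torch.Tensor) -> bool:
    # Simplified stand-in for check_flash_attention_available: only the device
    # type is checked here, whereas the patch also checks the IPEX version and
    # XeTLA support.
    return x.device.type == "xpu"

def pick_attention_dtype(hidden_states: torch.Tensor, q_len: int, training: bool):
    # Mirrors the control flow of the hunk above: flash attention is only
    # considered during inference, and fp16 is only chosen for the multi-token
    # (first-token / prefill) pass.
    if not training and not hidden_states.requires_grad:
        fsdp_flag = _flash_available(hidden_states)
    else:
        fsdp_flag = False
    if fsdp_flag and q_len > 1:
        return fsdp_flag, torch.float16       # flash path computes in fp16
    return fsdp_flag, hidden_states.dtype     # decode/training keep the original dtype

# On CPU the flash path is unavailable, so the original dtype is kept.
h = torch.randn(1, 32, 4096, dtype=torch.float32)
print(pick_attention_dtype(h, q_len=32, training=False))  # (False, torch.float32)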
@@ -194,31 +204,23 @@ def llama_attention_forward_4_31(

     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=hidden_states.dtype)
+                                                                     dtype=attention_dtype)
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=hidden_states.dtype)
+                                                                         dtype=attention_dtype)

-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-    attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
-
-    if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
+    if fsdp_flag and q_len > 1:
+        # now only use flash attention for first token
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    else:
+        # otherwise, use native attention
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)

     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
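For the prefill case (q_len equal to kv_seq_len with no padding), the two branches compute the same quantity: `is_causal=True` applies the same lower-triangular mask that the native path receives as an additive `attention_mask`. A small CPU sketch in fp32 that checks this equivalence; the shapes and tolerance are illustrative, not taken from the patch:

import math
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 16, 32
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Flash-style branch: fused (or fallback) kernel with an implicit causal mask.
flash_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Native-style branch: explicit scores, additive causal mask, softmax, weighted sum.
causal_mask = torch.full((q_len, q_len), float("-inf")).triu(diagonal=1)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim) + causal_mask
attn_weights = torch.softmax(attn_weights, dim=-1)
native_out = torch.matmul(attn_weights, v)

print(torch.allclose(flash_out, native_out, atol=1e-4))  # expected: True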
@@ -241,4 +243,46 @@ def llama_attention_forward_4_31(
     if not output_attentions:
         attn_weights = None

-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def check_flash_attention_available(query):
+    # check whether ipex flash attention can be used
+    if query.device.type != "xpu":
+        # ipex flash attention only support for xpu
+        return False
+    ipex_version = get_ipex_version()
+    if ipex_version <= "2.0.110+xpu":
+        # ipex flash attention is supported from ipex 2.1
+        return False
+    if not torch.xpu.has_xetla():
+        # ipex flash attention is only supported for xetla
+        # may update this later
+        return False
+    return True
+
+
+def native_sdp(query, key, value, attention_mask,
+               bsz, q_len, kv_seq_len, head_dim, num_heads):
+    attn_weights = torch.matmul(query,
+                                key.transpose(2, 3)) / math.sqrt(head_dim)
+
+    attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
+    if attn_weights.size() != attn_weights_size:
+        invalidInputError(False,
+                          f"Attention weights should be of size {attn_weights_size}, "
+                          f"but is {attn_weights.size()}")
+
+    if attention_mask is not None:
+        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+        if attention_mask.size() != attn_mask_size:
+            invalidInputError(False,
+                              f"Attention mask should be of size {attn_mask_size}, "
+                              f"but is {attention_mask.size()}")
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                         dtype=torch.float32).to(value.dtype)
+    attn_output = torch.matmul(attn_weights, value)
+    return attn_output, attn_weights
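The second commit-message bullet ("accelerate fp32 first token by changing to fp16 in advance") depends on the attention output being cast back at the end, so an fp32 caller still receives fp32 tensors even though the flash kernel ran in fp16. A rough CPU sketch of that round trip; bfloat16 stands in for float16 here because it is reliably supported on CPU, and the numbers are illustrative only, since the real speedup comes from the XPU fp16 flash kernel that this sketch does not use:

import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 32, dtype=torch.float32)
k = torch.randn(1, 4, 16, 32, dtype=torch.float32)
v = torch.randn(1, 4, 16, 32, dtype=torch.float32)

# Reference: attention computed entirely in fp32.
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Patch-style detour: compute in a lower precision, then cast back to the
# caller's dtype, mirroring `attn_output.to(original_dtype)` above.
low = torch.bfloat16
out = F.scaled_dot_product_attention(q.to(low), k.to(low), v.to(low), is_causal=True)
out = out.to(torch.float32)

print(out.dtype)                # torch.float32 -- callers still see fp32
print((out - ref).abs().max())  # small rounding error from the low-precision compute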