LLM: add flash attention for mistral / mixtral (#9846)
* add flash attention for mistral
* update
* add flash attn for mixtral
* fix style
parent afaa871144
commit dc995006cc
2 changed files with 112 additions and 39 deletions
bigdl.llm.transformers.models.mistral:

@@ -40,6 +40,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
+import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\

@@ -47,6 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.models.utils import use_flash_attention


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
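The gate itself, use_flash_attention, lives in bigdl.llm.transformers.models.utils and is not shown in this diff. As a hypothetical illustration only (the name, thresholds, and conditions below are assumptions, not BigDL's implementation), such a gate typically checks that the call is a multi-token prefill on a supported device with kernel-friendly shapes:

import torch

def could_use_flash_attention(query: torch.Tensor, key: torch.Tensor) -> bool:
    # Hypothetical sketch, not the BigDL heuristic.
    # query: (bsz, num_heads, q_len, head_dim); key: (bsz, num_kv_heads, kv_seq_len, head_dim)
    if query.device.type == "cpu":
        return False                 # the fused fp16 path targets GPU-class devices
    bsz, num_heads, q_len, head_dim = query.size()
    if q_len == 1:
        return False                 # single-token decode: a fused prefill kernel does not help
    if q_len != key.size(2):
        return False                 # with a populated KV cache, is_causal=True would mask wrongly
    return head_dim % 8 == 0         # alignment commonly required by fused kernels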
@@ -126,8 +128,10 @@ def mistral_attention_forward(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
@@ -214,21 +218,43 @@ def mistral_attention_forward(
     past_key_value = (key_states, value_states) if use_cache else None

-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
-                                                              bsz, q_len, kv_seq_len,
-                                                              self.num_heads, self.head_dim,
-                                                              self.hidden_size, attention_mask)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    else:
+        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
+                                                                 key_states,
+                                                                 value_states,
+                                                                 bsz, q_len, kv_seq_len,
+                                                                 self.num_heads, self.head_dim,
+                                                                 self.hidden_size, attention_mask)

     attn_output = self.o_proj(attn_output)

     if not output_attentions:
         attn_weights = None

-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value


 def mistral_attention_forward_4_36(
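Pulled out of the diff context, the dispatch added in the hunk above boils down to the following self-contained sketch (an illustration under assumed (bsz, num_heads, seq_len, head_dim) prefill shapes, not the library code): run fp16 SDPA when the gate fires, otherwise fall back to explicit masked softmax, and cast the result back to the caller's dtype.

import torch
import torch.nn.functional as F

def attend(query, key, value, use_flash: bool):
    # query/key/value: (bsz, num_heads, seq_len, head_dim), already expanded by repeat_kv
    bsz, num_heads, q_len, head_dim = query.shape
    original_dtype = query.dtype
    if use_flash:
        # fused path: fp16 on non-CPU devices, causal masking done inside the kernel
        dtype = torch.float16 if query.device.type != "cpu" else original_dtype
        out = F.scaled_dot_product_attention(query.to(dtype), key.to(dtype), value.to(dtype),
                                             is_causal=True)
    else:
        # reference path: explicit scores, additive causal mask, fp32 softmax
        scores = query @ key.transpose(2, 3) / head_dim ** 0.5
        causal = torch.full((q_len, q_len), float("-inf"), device=query.device).triu(1)
        weights = torch.softmax(scores + causal, dim=-1, dtype=torch.float32).to(original_dtype)
        out = weights @ value
    # back to (bsz, q_len, num_heads * head_dim) and the caller's dtype, as in the diff
    return out.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim).to(original_dtype)

# e.g. a Mistral-7B-like head layout for a 128-token prefill:
# y = attend(torch.randn(1, 32, 128, 128), torch.randn(1, 32, 128, 128),
#            torch.randn(1, 32, 128, 128), use_flash=True)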
@@ -241,8 +267,10 @@ def mistral_attention_forward_4_36(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -345,18 +373,42 @@ def mistral_attention_forward_4_36(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states

-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
-                                                              bsz, q_len, kv_seq_len,
-                                                              self.num_heads, self.head_dim,
-                                                              self.hidden_size, attention_mask)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    else:
+        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
+                                                                 key_states,
+                                                                 value_states,
+                                                                 bsz, q_len, kv_seq_len,
+                                                                 self.num_heads,
+                                                                 self.head_dim,
+                                                                 self.hidden_size,
+                                                                 attention_mask)

     attn_output = self.o_proj(attn_output)

     if not output_attentions:
         attn_weights = None

-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
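Because the gate only fires when the module is not training and the hidden states carry no grad, the fast path is exercised under plain inference. A minimal usage sketch on an Intel GPU (the checkpoint path and prompt are placeholders, and the from_pretrained options follow bigdl-llm's standard low-bit API; check your installed version's docs for the exact names):

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM
import intel_extension_for_pytorch as ipex  # noqa: F401  (enables the 'xpu' device)

model_path = "mistralai/Mistral-7B-Instruct-v0.1"   # placeholder: any Mistral / Mixtral checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True)
model = model.to("xpu")

with torch.inference_mode():   # no grads, so the flash-attention gate can take effect
    inputs = tokenizer("What does flash attention change?", return_tensors="pt").to("xpu")
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))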
bigdl.llm.transformers.models.mixtral:

@@ -49,6 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
+from bigdl.llm.transformers.models.utils import use_flash_attention


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -142,6 +143,8 @@ def mixtral_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype

     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -241,35 +244,53 @@ def mixtral_attention_forward(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states

+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)

-    attn_weights = torch.matmul(
-        query_states,
-        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-        invalidInputError(
-            False,
-            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
-            f" but is {attn_weights.size()}"
-        )
-
-    if attention_mask is not None:
-        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                f" but is {attention_mask.size()}"
-            )
-
-        attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.\
-        softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(
+            query_states,
+            key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            invalidInputError(
+                False,
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
+                f" but is {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
+
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.\
+            softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)

     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
         invalidInputError(
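Since the eager matmul/softmax path is kept as the else branch, it doubles as a reference for sanity-checking the fused call. A small self-contained check in plain PyTorch (not BigDL code; it assumes a prefill where q_len equals kv_seq_len, which is what makes is_causal=True valid):

import math
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 1, 8, 64, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# fused path, as used behind fsdp_flag
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# eager path, mirroring the else branch: scaled scores, additive causal mask, fp32 softmax
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
causal = torch.full((seq_len, seq_len), float("-inf")).triu(1)
weights = torch.softmax(scores + causal, dim=-1, dtype=torch.float32).to(q.dtype)
eager = torch.matmul(weights, v)

print(torch.allclose(fused, eager, atol=1e-4))   # expected: True, up to numerical tolerance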