LLM: add flash attention for mistral / mixtral (#9846)

* add flash attention for mistral

* update

* add flash attn for mixtral

* fix style
Ruonan Wang 2024-01-08 09:51:34 +08:00 committed by GitHub
parent afaa871144
commit dc995006cc
2 changed files with 112 additions and 39 deletions
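
Both changed files apply the same change: on inference-only calls, an eligibility check (use_flash_attention) decides whether to run PyTorch's scaled_dot_product_attention in fp16 or to fall back to the existing eager attention path in the original dtype, casting the output back before returning. The sketch below condenses that dispatch logic into a self-contained example; flash_ok and eager_attention are illustrative stand-ins for the repo's use_flash_attention and compute_attn_outputs_weights helpers, whose bodies are not part of this diff, and attention-mask handling is omitted for brevity.

import math

import torch
import torch.nn.functional as F


def flash_ok(query_states: torch.Tensor, key_states: torch.Tensor) -> bool:
    # Stand-in for use_flash_attention(): a typical check is "prefill with no
    # padding", i.e. query and key cover the same (long enough) sequence.
    return query_states.size(2) == key_states.size(2) and query_states.size(2) > 1


def eager_attention(q, k, v):
    # Existing fallback: softmax(QK^T / sqrt(d)) V, kept in the original dtype.
    # (The real code also adds the attention mask before the softmax.)
    scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1))
    weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    return torch.matmul(weights, v), weights


def attention_dispatch(q, k, v, training=False):
    # q, k, v: (bsz, num_heads, seq_len, head_dim), with k/v already repeated
    # for grouped-query attention.
    bsz, num_heads, q_len, head_dim = q.shape
    original_dtype = q.dtype
    use_flash = (not training) and (not q.requires_grad) and flash_ok(q, k)
    attention_dtype = torch.float16 if use_flash else original_dtype
    k = k.to(dtype=attention_dtype)
    v = v.to(dtype=attention_dtype)
    if use_flash:
        attn_output = F.scaled_dot_product_attention(q.to(dtype=attention_dtype),
                                                     k, v, is_causal=True)
        attn_weights = None
    else:
        attn_output, attn_weights = eager_attention(q, k, v)
    attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
    # Cast back so the rest of the model keeps running in its original dtype.
    return attn_output.to(original_dtype), attn_weights

With this stand-in check, a prefill call (q_len == kv_len) takes the fp16 flash path, while single-token decode steps fall back to the eager path, which mirrors the shape of the changes below.
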

@@ -40,6 +40,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
+import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
@@ -47,6 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.models.utils import use_flash_attention
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -126,8 +128,10 @@ def mistral_attention_forward(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
@@ -214,11 +218,33 @@ def mistral_attention_forward(
     past_key_value = (key_states, value_states) if use_cache else None
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
-                                                              bsz, q_len, kv_seq_len,
-                                                              self.num_heads, self.head_dim,
-                                                              self.hidden_size, attention_mask)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    else:
+        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
+                                                                  key_states,
+                                                                  value_states,
+                                                                  bsz, q_len, kv_seq_len,
+                                                                  self.num_heads, self.head_dim,
+                                                                  self.hidden_size, attention_mask)
@@ -228,7 +254,7 @@ def mistral_attention_forward(
     if not output_attentions:
         attn_weights = None
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
 
 
 def mistral_attention_forward_4_36(
@@ -241,8 +267,10 @@ def mistral_attention_forward_4_36(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -345,18 +373,42 @@ def mistral_attention_forward_4_36(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
-                                                              bsz, q_len, kv_seq_len,
-                                                              self.num_heads, self.head_dim,
-                                                              self.hidden_size, attention_mask)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    else:
+        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
+                                                                  key_states,
+                                                                  value_states,
+                                                                  bsz, q_len, kv_seq_len,
+                                                                  self.num_heads,
+                                                                  self.head_dim,
+                                                                  self.hidden_size,
+                                                                  attention_mask)
 
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
         attn_weights = None
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
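
Every one of these forward paths expands the KV cache with repeat_kv(..., self.num_key_value_groups) before attention; the commit only adds the .to(device, dtype=attention_dtype) cast to that call. repeat_kv itself is not part of this diff; for reference, the standard Hugging Face transformers implementation it corresponds to looks like this:

import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand (bsz, num_kv_heads, seq_len, head_dim) to
    # (bsz, num_kv_heads * n_rep, seq_len, head_dim) so grouped-query K/V
    # line up with the query heads; expand() avoids a copy until reshape().
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

For Mistral-7B (32 query heads, 8 KV heads) num_key_value_groups is 4, so K/V are expanded 4x and, with this commit, land on the target device in attention_dtype in the same statement. The second changed file below makes the same change in mixtral_attention_forward.
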

@@ -49,6 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
+from bigdl.llm.transformers.models.utils import use_flash_attention
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -142,6 +143,8 @@ def mixtral_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -241,10 +244,28 @@ def mixtral_attention_forward(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-    attn_weights = torch.matmul(
-        query_states,
-        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(
+            query_states,
+            key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
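
The excerpt ends inside the eager fallback, right after the raw QK^T / sqrt(d) scores; as in the Mistral hunks above, the rest of that path is presumably wrapped unchanged under the new else branch. As an illustration of why running the flash path in fp16 is an acceptable trade-off, a small self-contained comparison of the two paths (not part of the commit; drop the .half() calls if your PyTorch build lacks fp16 matmul support):

import math

import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 1, 8, 128, 64
q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)

# Flash-style path: fp16 fused kernel with built-in causal masking.
# (.half() assumes fp16 kernels are available for the current device.)
flash_out = F.scaled_dot_product_attention(q.half(), k.half(), v.half(),
                                           is_causal=True).float()

# Eager path: explicit causal mask, fp32 softmax(QK^T / sqrt(d)) V.
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
eager_out = torch.matmul(torch.softmax(scores + causal_mask, dim=-1), v)

print(torch.max((flash_out - eager_out).abs()))  # typically around 1e-3..1e-2
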