diff --git a/python/llm/src/bigdl/llm/transformers/models/mistral.py b/python/llm/src/bigdl/llm/transformers/models/mistral.py
index 840d2dfb..e769e51c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mistral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mistral.py
@@ -40,6 +40,7 @@
 from typing import Optional, Tuple
 import torch
 from torch import nn
+import torch.nn.functional as F
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
@@ -47,6 +48,7 @@ from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
 from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31,\
     is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.transformers.models.utils import use_flash_attention
 
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -126,8 +128,10 @@ def mistral_attention_forward(
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value)
@@ -214,21 +218,43 @@ def mistral_attention_forward(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
 
-    attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
-                                                              bsz, q_len, kv_seq_len,
-                                                              self.num_heads, self.head_dim,
-                                                              self.hidden_size, attention_mask)
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    else:
+        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
+                                                                  key_states,
+                                                                  value_states,
+                                                                  bsz, q_len, kv_seq_len,
+                                                                  self.num_heads, self.head_dim,
+                                                                  self.hidden_size, attention_mask)
 
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
 
 
 def mistral_attention_forward_4_36(
@@ -241,8 +267,10 @@
     use_cache: bool=False,
     padding_mask: Optional[torch.Tensor]=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    bsz, q_len, _ = hidden_states.size()
+    bsz, q_len, hidden_size = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -345,18 +373,42 @@ def mistral_attention_forward_4_36(
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
 
-    # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
 
-    attn_output, attn_weights = compute_attn_outputs_weights(query_states, key_states, value_states,
-                                                              bsz, q_len, kv_seq_len,
-                                                              self.num_heads, self.head_dim,
-                                                              self.hidden_size, attention_mask)
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, hidden_size)
+    else:
+        attn_output, attn_weights = compute_attn_outputs_weights(query_states,
+                                                                  key_states,
+                                                                  value_states,
+                                                                  bsz, q_len, kv_seq_len,
+                                                                  self.num_heads,
+                                                                  self.head_dim,
+                                                                  self.hidden_size,
+                                                                  attention_mask)
 
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index cebea709..36251834 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -49,6 +49,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
     apply_rotary_pos_emb_no_cache_xpu, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.mistral import should_use_fuse_rope, use_decoding_fast_path
+from bigdl.llm.transformers.models.utils import use_flash_attention
 
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -142,6 +143,8 @@
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
@@ -241,35 +244,53 @@
         past_key_value.key_cache[self.layer_idx] = key_states
         past_key_value.value_cache[self.layer_idx] = value_states
 
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = use_flash_attention(query_states, key_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
+
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups)
-    value_states = repeat_kv(value_states, self.num_key_value_groups)
+    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
+                                                                     dtype=attention_dtype)
+    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
+                                                                         dtype=attention_dtype)
 
-    attn_weights = torch.matmul(
-        query_states,
-        key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(
+            query_states,
+            key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-        invalidInputError(
-            False,
-            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
-            f" but is {attn_weights.size()}"
-        )
-
-    if attention_mask is not None:
-        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             invalidInputError(
                 False,
-                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                f" but is {attention_mask.size()}"
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)},"
+                f" but is {attn_weights.size()}"
             )
 
-        attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                invalidInputError(
+                    False,
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                    f" but is {attention_mask.size()}"
+                )
 
-    # upcast attention to fp32
-    attn_weights = nn.functional.\
-        softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.\
+            softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
 
     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
         invalidInputError(
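
Reviewer note (illustration only, not part of the patch): all three rewritten forwards share the same gating pattern -- compute fsdp_flag only at inference time, cast k/v (and q) to fp16 when the fused kernel is taken, call F.scaled_dot_product_attention with is_causal=True, and cast the returned tensor back to original_dtype so callers never observe the fp16 detour. The sketch below is a minimal, self-contained rendering of that pattern under stated assumptions; can_use_flash is a hypothetical stand-in for bigdl.llm.transformers.models.utils.use_flash_attention, whose real eligibility rules are not reproduced here.

# sdpa_gating_sketch.py -- illustration only, not part of the patch.
import torch
import torch.nn.functional as F


def can_use_flash(query_states, key_states):
    # Hypothetical check: only non-CPU prefill shapes take the fused path.
    return (query_states.device.type != "cpu"
            and query_states.size(2) > 1
            and query_states.size(2) == key_states.size(2))


def attention(query_states, key_states, value_states, training=False):
    # Tensors are (bsz, num_heads, seq_len, head_dim).
    original_dtype = query_states.dtype
    # Mirrors `not self.training and not hidden_states.requires_grad`.
    fsdp_flag = can_use_flash(query_states, key_states) if not training else False
    attention_dtype = torch.float16 if fsdp_flag else original_dtype

    key_states = key_states.to(dtype=attention_dtype)
    value_states = value_states.to(dtype=attention_dtype)

    if fsdp_flag:
        # Fused kernel path, as in the patch.
        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
                                                     key_states,
                                                     value_states,
                                                     is_causal=True)
    else:
        # Plain attention fallback (attention masking omitted for brevity).
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = attn_weights / (key_states.size(-1) ** 0.5)
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attention_dtype)
        attn_output = torch.matmul(attn_weights, value_states)

    # Cast back so callers always see the dtype they passed in, like the
    # patch's `return attn_output.to(original_dtype), ...`.
    return attn_output.to(original_dtype)


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 16, 64)
    print(attention(q, k, v).shape)  # torch.Size([1, 8, 16, 64])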