diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py index 44a10d56..60a02912 100644 --- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py +++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py @@ -24,8 +24,10 @@ from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn +import torch.nn.functional as F from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from bigdl.llm.utils.common import invalidInputError +from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ append_kv_cache, is_enough_kv_cache_room_4_31 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv_cache, \ @@ -267,25 +269,43 @@ def baichuan_attention_forward_7b_origin( past_key_value = (key_states, value_states) if use_cache else None - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + if not self.training and not hidden_states.requires_grad and \ + use_flash_attention(query_states, key_states, attention_mask): + attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16), + key_states.to(device, dtype=torch.float16), + value_states.to(device, dtype=torch.float16), + is_causal=True) + attn_weights = None + elif not self.training and not hidden_states.requires_grad and \ + use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + import linear_fp16_esimd + attn_output = linear_fp16_esimd.sdp_forward(query_states, + key_states, + value_states) + attn_output = attn_output.view(query_states.shape) + attn_weights = None + else: + attn_weights = torch.matmul(query_states, + key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - invalidInputError(False, - f"Attention weights should be of size " - f"{(bsz, self.num_heads, q_len, kv_seq_len)}" - f", but is {attn_weights.size()}") + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + invalidInputError(False, + f"Attention weights should be of size " + f"{(bsz, self.num_heads, q_len, kv_seq_len)}" + f", but is {attn_weights.size()}") - if attention_mask is not None: - invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " - f"but is {attention_mask.size()}") - attn_weights = attn_weights + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + if attention_mask is not None: + invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len), + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " + f"but is {attention_mask.size()}") + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, + torch.tensor(torch.finfo(attn_weights.dtype).min)) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, - dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, + dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), f"`attn_output` should be of size " @@ -300,7 +320,7 @@ def baichuan_attention_forward_7b_origin( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output.to(hidden_states.dtype), attn_weights, past_key_value def baichuan_attention_forward_13b( @@ -502,4 +522,4 @@ def baichuan_attention_forward_13b_origin( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output.to(hidden_states.dtype), attn_weights, past_key_value diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py index 1fc8e403..844577a6 100644 --- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py +++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py @@ -28,6 +28,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv restore_fp8_kv_cache, use_quantize_kv_cache from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, \ append_kv_cache, is_enough_kv_cache_room_4_31 +from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb, SILU from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu from bigdl.llm.transformers.models.utils import mlp_fusion_check @@ -271,16 +272,32 @@ def baichuan_attention_forward_7b_origin( query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask() ) else: - if attention_mask is not None: - if attention_mask.dtype == torch.bool: - attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf")) + if not self.training and not hidden_states.requires_grad and \ + use_flash_attention(query_states, key_states, attention_mask): + attn_output = F.scaled_dot_product_attention(query_states.to(dtype=torch.float16), + key_states.to(dtype=torch.float16), + value_states.to(dtype=torch.float16), + is_causal=True) + attn_weights = None + elif not self.training and not hidden_states.requires_grad and \ + use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): + import linear_fp16_esimd + attn_output = linear_fp16_esimd.sdp_forward(query_states, + key_states, + value_states) + attn_output = attn_output.view(query_states.shape) + attn_weights = None + else: + if attention_mask is not None: + if attention_mask.dtype == torch.bool: + attention_mask.masked_fill_(attention_mask.logical_not(), float("-inf")) - scaling_factor = 1 / math.sqrt(query_states.size(-1)) - attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1)) - if attention_mask is not None: - attn_output += attention_mask - attn_output = torch.softmax(attn_output, -1) - attn_output = torch.matmul(attn_output, value_states) + scaling_factor = 1 / math.sqrt(query_states.size(-1)) + attn_output = torch.matmul(query_states * scaling_factor, key_states.transpose(-2, -1)) + if attention_mask is not None: + attn_output += attention_mask + attn_output = torch.softmax(attn_output, -1) + attn_output = torch.matmul(attn_output, value_states) attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) @@ -289,7 +306,7 @@ def baichuan_attention_forward_7b_origin( if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output.to(hidden_states.dtype), attn_weights, past_key_value def baichuan_attention_forward_13b( diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py index 9df13355..371e5029 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py @@ -348,6 +348,13 @@ def qwen2_attention_forward_origin( value_states = repeat_kv(value_states, self.num_key_value_groups) if not self.training and not hidden_states.requires_grad and \ + use_flash_attention(query_states, key_states, attention_mask): + attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16), + key_states.to(device, dtype=torch.float16), + value_states.to(device, dtype=torch.float16), + is_causal=True) + attn_weights = None + elif not self.training and not hidden_states.requires_grad and \ use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states): import linear_fp16_esimd attn_output = linear_fp16_esimd.sdp_forward(query_states, @@ -379,12 +386,12 @@ def qwen2_attention_forward_origin( training=self.training) attn_output = torch.matmul(attn_weights, value_states) - invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), - "`attn_output` should be of size " - f"{(bsz, self.num_heads, q_len, self.head_dim)}," - f" but is {attn_output.size()}") + invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim), + "`attn_output` should be of size " + f"{(bsz, self.num_heads, q_len, self.head_dim)}," + f" but is {attn_output.size()}") - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output)