diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index a5f7e021..1189c21e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -106,6 +106,16 @@ def llama_attention_forward_4_31(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+    if not self.training and not hidden_states.requires_grad:
+        fsdp_flag = check_flash_attention_available(hidden_states)
+    else:
+        fsdp_flag = False
+    if fsdp_flag and q_len > 1:
+        attention_dtype = torch.float16  # use fp16 for flash attention
+    else:
+        attention_dtype = original_dtype
 
     if self.config.pretraining_tp > 1:
         key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
@@ -194,31 +204,23 @@ def llama_attention_forward_4_31(
 
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=hidden_states.dtype)
+                                                                     dtype=attention_dtype)
     value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=hidden_states.dtype)
+                                                                         dtype=attention_dtype)
 
-    attn_weights = torch.matmul(query_states,
-                                key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-    attn_weights_size = (bsz, self.num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
-
-    if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
-
-    # upcast attention to fp32
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                         dtype=torch.float32).to(query_states.dtype)
-    attn_output = torch.matmul(attn_weights, value_states)
+    if fsdp_flag and q_len > 1:
+        # now only use flash attention for first token
+        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
+                                                     key_states,
+                                                     value_states,
+                                                     is_causal=True)
+        attn_weights = None
+    else:
+        # otherwise, use native attention
+        attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                               attention_mask,
+                                               bsz, q_len, kv_seq_len,
+                                               self.head_dim, self.num_heads)
 
     attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
     if attn_output.size() != attn_output_size:
@@ -241,4 +243,46 @@ def llama_attention_forward_4_31(
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def check_flash_attention_available(query):
+    # check whether ipex flash attention can be used
+    if query.device.type != "xpu":
+        # ipex flash attention only support for xpu
+        return False
+    ipex_version = get_ipex_version()
+    if ipex_version <= "2.0.110+xpu":
+        # ipex flash attention is supported from ipex 2.1
+        return False
+    if not torch.xpu.has_xetla():
+        # ipex flash attention is only supported for xetla
+        # may update this later
+        return False
+    return True
+
+
+def native_sdp(query, key, value, attention_mask,
+               bsz, q_len, kv_seq_len, head_dim, num_heads):
+    attn_weights = torch.matmul(query,
+                                key.transpose(2, 3)) / math.sqrt(head_dim)
+
+    attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
+    if attn_weights.size() != attn_weights_size:
+        invalidInputError(False,
+                          f"Attention weights should be of size {attn_weights_size}, "
+                          f"but is {attn_weights.size()}")
+
+    if attention_mask is not None:
+        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+        if attention_mask.size() != attn_mask_size:
+            invalidInputError(False,
+                              f"Attention mask should be of size {attn_mask_size}, "
+                              f"but is {attention_mask.size()}")
+        attn_weights = attn_weights + attention_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                         dtype=torch.float32).to(value.dtype)
+    attn_output = torch.matmul(attn_weights, value)
+    return attn_output, attn_weights
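
Reviewer note (not part of the patch): the native_sdp fallback above computes standard scaled dot-product attention with an fp32 softmax. The standalone sketch below, written against stock PyTorch on CPU with made-up shapes, reproduces that math (shape checks and attention mask omitted) and compares it with F.scaled_dot_product_attention to illustrate that the two code paths are numerically equivalent when no mask is applied.

# minimal sketch, assuming plain PyTorch >= 2.0 on CPU; shapes are illustrative only
import math
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, kv_seq_len, head_dim = 1, 4, 8, 8, 16
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, kv_seq_len, head_dim)
value = torch.randn(bsz, num_heads, kv_seq_len, head_dim)

# same math as native_sdp: QK^T / sqrt(d), fp32 softmax, then weighted sum of values
attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value.dtype)
attn_output = torch.matmul(attn_weights, value)

# reference: PyTorch's fused scaled dot-product attention (no mask, non-causal)
reference = F.scaled_dot_product_attention(query, key, value)
print(torch.allclose(attn_output, reference, atol=1e-5))  # expected: True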