From 24473e331a3e5eb12351d733c95bf2c073b0782e Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Fri, 15 Mar 2024 13:12:03 +0800
Subject: [PATCH] Qwen2 fp16 sdp (#10427)

* qwen2 sdp and refine

* update

* update

* fix style

* remove use_flash_attention
---
 .../bigdl/llm/transformers/models/llama.py    | 36 +++++-------
 .../src/bigdl/llm/transformers/models/qwen.py |  9 ++-
 .../bigdl/llm/transformers/models/qwen2.py    | 58 +++++++++++--------
 3 files changed, 52 insertions(+), 51 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 0b15e066..378a7cb5 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -604,20 +604,19 @@ def llama_attention_forward_4_31_original(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    fsdp_flag = not self.training and not hidden_states.requires_grad and \
-        use_flash_attention(query_states, key_states, attention_mask)
-
     # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
         attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
                                                      key_states.to(device, dtype=torch.float16),
                                                      value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
@@ -1249,29 +1248,20 @@ def llama_attention_forward_4_36_original(
             past_key_value.key_cache[self.layer_idx] = key_states
             past_key_value.value_cache[self.layer_idx] = value_states
 
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
-
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, attention_mask):
         # now only use flash attention for first token
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query_states,
                                                     key_states,
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 8c348a90..a2ca1bbe 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -273,10 +273,8 @@ def qwen_attention_forward_original(
     if not decoding_fast_path:
         query = query.transpose(1, 2)
 
-    fsdp_flag = not self.training and not hidden_states.requires_grad and \
-        use_flash_attention(query, key)
-
-    if fsdp_flag:
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query, key):
         attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
                                                      key.to(device, dtype=torch.float16),
                                                      value.to(device, dtype=torch.float16),
@@ -284,7 +282,8 @@
         attn_output = attn_output.view(query.shape)
         attn_output = attn_output.transpose(1, 2)
         attn_weights = None
-    elif use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+    elif not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
         import linear_fp16_esimd
         attn_output = linear_fp16_esimd.sdp_forward(query,
                                                     key,
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
index e3482119..9df13355 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
@@ -43,6 +43,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from bigdl.llm.transformers.models.llama import repeat_kv
 from bigdl.llm.transformers.models.utils import extend_kv_cache, append_kv_cache
@@ -51,6 +52,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
 from bigdl.llm.transformers.kv import DynamicFp8Cache
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 
 from transformers.models.qwen2.modeling_qwen2 import Qwen2Model, apply_rotary_pos_emb
 
@@ -345,34 +347,44 @@ def qwen2_attention_forward_origin(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+    if not self.training and not hidden_states.requires_grad and \
+            use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query_states,
+                                                    key_states,
+                                                    value_states)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states,
+                                    key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-    invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
-                      ("Attention weights should be of size "
-                       f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
-                       "but is {attn_weights.size()}"))
+        invalidInputError(attn_weights.size() == (bsz, self.num_heads, q_len, kv_seq_len),
+                          ("Attention weights should be of size "
+                           f"{(bsz, self.num_heads, q_len, kv_seq_len)},"
+                           "but is {attn_weights.size()}"))
 
-    if attention_mask is not None:
-        invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
-                          (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
-                           f" but is {attention_mask.size()}"))
+        if attention_mask is not None:
+            invalidInputError(attention_mask.size() == (bsz, 1, q_len, kv_seq_len),
+                              (f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}"
+                               f" but is {attention_mask.size()}"))
 
-        attn_weights = attn_weights + attention_mask
+            attn_weights = attn_weights + attention_mask
 
-    # upcast attention to fp32
-    attn_weights = \
-        nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-    attn_weights = nn.functional.dropout(attn_weights,
-                                         p=self.attention_dropout,
-                                         training=self.training)
-    attn_output = torch.matmul(attn_weights, value_states)
+        # upcast attention to fp32
+        attn_weights = \
+            nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights,
+                                             p=self.attention_dropout,
+                                             training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
 
-    invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
-                      "`attn_output` should be of size "
-                      f"{(bsz, self.num_heads, q_len, self.head_dim)},"
-                      f" but is {attn_output.size()}")
+        invalidInputError(attn_output.size() == (bsz, self.num_heads, q_len, self.head_dim),
+                          "`attn_output` should be of size "
+                          f"{(bsz, self.num_heads, q_len, self.head_dim)},"
+                          f" but is {attn_output.size()}")
 
-    attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
     attn_output = self.o_proj(attn_output)
@@ -380,7 +392,7 @@
 
     if not output_attentions:
         attn_weights = None
 
-    return attn_output, attn_weights, past_key_value
+    return attn_output.to(hidden_states.dtype), attn_weights, past_key_value
 
 
 def qwen2_sdpa_attention_forward(