From cda38f85a978ba7601a928506ac33833c25d3606 Mon Sep 17 00:00:00 2001
From: Xin Qiu
Date: Fri, 15 Mar 2024 08:51:50 +0800
Subject: [PATCH] Qwen fp16 sdp (#10401)

* qwen sdp

* fix

* update

* update

* update sdp

* update

* fix style check

* add to origin type
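
In short: qwen_attention_forward_original now selects its attention kernel the
same way llama.py already does. At inference time it first tries the fp16
torch scaled_dot_product_attention fast path, then the ESIMD SDP kernel, and
only then falls back to the model's own _attn. A minimal runnable sketch of
that dispatch order (the two predicate stubs below are simplified stand-ins
for use_flash_attention and use_esimd_sdp from models/utils.py, whose real
checks also inspect device, dtype, and shape):

    import torch
    import torch.nn.functional as F

    def can_use_flash_attention(query, key):
        # Stand-in for use_flash_attention(): the real check also requires
        # an XPU device and inference-friendly shapes and dtypes.
        return not query.requires_grad

    def can_use_esimd_sdp(q_len, kv_len, head_dim, query):
        # Stand-in for use_esimd_sdp(): the real check also requires the
        # linear_fp16_esimd XPU kernel to be importable.
        return False

    def dispatch_attention(query, key, value, training=False):
        # query/key/value: (batch, num_heads, seq_len, head_dim)
        fsdp_flag = not training and not query.requires_grad and \
            can_use_flash_attention(query, key)
        if fsdp_flag:
            # fp16 SDP fast path, as in the diff below.
            return F.scaled_dot_product_attention(query.to(torch.float16),
                                                  key.to(torch.float16),
                                                  value.to(torch.float16),
                                                  is_causal=True)
        if can_use_esimd_sdp(query.shape[2], key.shape[2], query.shape[3], query):
            import linear_fp16_esimd
            return linear_fp16_esimd.sdp_forward(query, key, value)
        return None  # caller falls back to self._attn(...)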
---
 .../bigdl/llm/transformers/models/llama.py    | 22 +++++--------
 .../src/bigdl/llm/transformers/models/qwen.py | 31 ++++++++++++++++---
 2 files changed, 34 insertions(+), 19 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 91dab9b5..0b15e066 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -604,25 +604,17 @@ def llama_attention_forward_4_31_original(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
-    if not self.training and not hidden_states.requires_grad:
-        fsdp_flag = use_flash_attention(query_states, key_states, attention_mask)
-    else:
-        fsdp_flag = False
-    if fsdp_flag:
-        attention_dtype = torch.float16  # use fp16 for flash attention
-    else:
-        attention_dtype = original_dtype
+    fsdp_flag = not self.training and not hidden_states.requires_grad and \
+        use_flash_attention(query_states, key_states, attention_mask)
 
     # repeat k/v heads if n_kv_heads < n_heads
-    key_states = repeat_kv(key_states, self.num_key_value_groups).to(device,
-                                                                     dtype=attention_dtype)
-    value_states = repeat_kv(value_states, self.num_key_value_groups).to(device,
-                                                                         dtype=attention_dtype)
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
 
     if fsdp_flag:
-        attn_output = F.scaled_dot_product_attention(query_states.to(dtype=attention_dtype),
-                                                     key_states,
-                                                     value_states,
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
                                                      is_causal=True)
         attn_weights = None
     elif use_esimd_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index fce6e20a..8c348a90 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -42,6 +42,7 @@ from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, append_fp8_kv
 from bigdl.llm.transformers.models.utils import rotate_half, SILU
 from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_cache_freq_xpu
+from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
@@ -131,6 +132,8 @@ def qwen_attention_forward_original(
                       "flash attn and kv_cache quantization are not supported")
     bsz, q_len, _ = hidden_states.size()
     device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
 
     use_fuse_rope = should_use_fuse_rope(self, hidden_states)
     decoding_fast_path = (use_fuse_rope and bsz * q_len == 1)
@@ -270,15 +273,35 @@ def qwen_attention_forward_original(
     if not decoding_fast_path:
         query = query.transpose(1, 2)
 
-    attn_output, attn_weight = self._attn(
-        query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
-    )
+    fsdp_flag = not self.training and not hidden_states.requires_grad and \
+        use_flash_attention(query, key)
+
+    if fsdp_flag:
+        attn_output = F.scaled_dot_product_attention(query.to(device, dtype=torch.float16),
+                                                     key.to(device, dtype=torch.float16),
+                                                     value.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_output = attn_output.view(query.shape)
+        attn_output = attn_output.transpose(1, 2)
+        attn_weight = None
+    elif use_esimd_sdp(q_len, key.shape[2], self.head_dim, query):
+        import linear_fp16_esimd
+        attn_output = linear_fp16_esimd.sdp_forward(query,
+                                                    key,
+                                                    value)
+        attn_output = attn_output.view(query.shape)
+        attn_output = attn_output.transpose(1, 2)
+        attn_weight = None
+    else:
+        attn_output, attn_weight = self._attn(
+            query.to(key.dtype), key, value, causal_mask, attention_mask, head_mask
+        )
 
     context_layer = self._merge_heads(
         attn_output, self.num_heads, self.head_dim
     )
 
-    attn_output = self.c_proj(context_layer)
+    attn_output = self.c_proj(context_layer).to(original_dtype)
 
     if use_cache:
         outputs = (attn_output, (key.transpose(1, 2), value.transpose(1, 2)))
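
A note on dtype handling in both files: q/k/v are cast to float16 only for the
scaled_dot_product_attention call, and the result is cast back to the original
dtype afterwards (in qwen.py via .to(original_dtype) after c_proj). A
self-contained sketch of that round trip, with made-up shapes; it assumes a
backend with fp16 SDP support (an XPU/CUDA device, or a recent PyTorch 2.x
CPU build):

    import torch
    import torch.nn.functional as F

    def sdp_fp16(query, key, value):
        # Cast to fp16 just for the kernel call, then restore the caller's
        # dtype, mirroring the .to(device, dtype=torch.float16) /
        # .to(original_dtype) pattern in the patch above.
        original_dtype = query.dtype
        out = F.scaled_dot_product_attention(query.to(torch.float16),
                                             key.to(torch.float16),
                                             value.to(torch.float16),
                                             is_causal=True)
        return out.to(original_dtype)

    # (batch, num_heads, seq_len, head_dim): example shapes, not from the patch.
    q, k, v = (torch.randn(1, 32, 64, 128) for _ in range(3))
    out = sdp_fp16(q, k, v)
    assert out.dtype == q.dtype and out.shape == q.shape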