From c0cd238e40427bd981efdc76764ec58d97c44318 Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Mon, 8 Apr 2024 11:43:15 +0800
Subject: [PATCH] LLM: support llama2 8k input with w4a16. (#10677)

* LLM: support llama2 8k input with w4a16.
* fix comment and style.
* fix style.
* fix comments and split tensor to quantized attention forward.
* fix style.
* refactor name.
* fix style.
* fix style.
* fix style.
* refactor checker name.
* refactor native sdp split qkv tensor name.
* fix style.
* fix comment rename variables.
* fix co-exist of intermedia results.
---
 .../src/ipex_llm/transformers/models/llama.py | 110 ++++++++++++------
 1 file changed, 76 insertions(+), 34 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index ec554874..56ac4e8a 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -214,6 +214,15 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
+def should_split_qkv_tensor(query_states, output_attentions):
+    if not output_attentions and query_states.dtype == torch.float16 and \
+            query_states.shape[2] >= 6800:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 6800 for now
+        return True
+    return False
+
+
 def llama_decoder_forward(
     self,
     hidden_states: torch.Tensor,
@@ -404,7 +413,7 @@ def llama_attention_forward_4_31_quantized(
         attn_output, attn_weights = native_sdp(query_states, repeated_key_states,
                                                repeated_value_states, attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
         if use_cache:
             k_cache, v_cache = init_fp8_kv_cache(
                 bsz, self.num_key_value_heads, kv_seq_len, self.head_dim,
@@ -429,7 +438,7 @@ def llama_attention_forward_4_31_quantized(
             attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                    attention_mask,
                                                    bsz, q_len, kv_seq_len,
-                                                   self.head_dim, self.num_heads)
+                                                   self.head_dim, self.num_heads, output_attentions)
         else:
             import linear_q4_0
             attn_output = linear_q4_0.sdp_fp8(query_states, key_states, value_states,
@@ -642,8 +651,7 @@ def llama_attention_forward_4_31_original(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
-
+                                               self.head_dim, self.num_heads, output_attentions)
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
         if attn_output.size() != attn_output_size:
             invalidInputError(False,
@@ -814,7 +822,8 @@ def llama_attention_selective_batching_forward_4_31(
                                                        1,
                                                        current_kv_len,
                                                        self.head_dim,
-                                                       self.num_heads)
+                                                       self.num_heads,
+                                                       output_attentions)
                 if attn_output.size() != (1, self.num_heads, 1, self.head_dim):
                     invalidInputError(False,
                                       f"`attn_output` should be of size "
@@ -858,7 +867,8 @@ def llama_attention_selective_batching_forward_4_31(
                                                q_len,
                                                kv_seq_len,
                                                self.head_dim,
-                                               self.num_heads)
+                                               self.num_heads,
+                                               output_attentions)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             invalidInputError(False,
@@ -1291,7 +1301,7 @@ def llama_attention_forward_4_36_original(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
 
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
         if attn_output.size() != attn_output_size:
@@ -1318,33 +1328,65 @@ def llama_attention_forward_4_36_original(
 
 
 def native_sdp(query, key, value, attention_mask,
-               bsz, q_len, kv_seq_len, head_dim, num_heads):
-    attn_weights = torch.matmul(query.to(key.dtype),
-                                key.transpose(2, 3)) / math.sqrt(head_dim)
-
-    attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
-    if attn_weights.size() != attn_weights_size:
-        invalidInputError(False,
-                          f"Attention weights should be of size {attn_weights_size}, "
-                          f"but is {attn_weights.size()}")
-
-    if attention_mask is not None:
-        attn_mask_size = (bsz, 1, q_len, kv_seq_len)
-        if attention_mask.size() != attn_mask_size:
-            invalidInputError(False,
-                              f"Attention mask should be of size {attn_mask_size}, "
-                              f"but is {attention_mask.size()}")
-        attn_weights = attn_weights + attention_mask
-
-    if kv_seq_len >= 2048:
-        # for memory considerations, do not upcast attention to fp32 for long sequences
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+               bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
+    if should_split_qkv_tensor(query, output_attentions):
+        return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+                                           bsz, q_len, kv_seq_len, head_dim)
     else:
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(value.dtype)
-    attn_output = torch.matmul(attn_weights, value)
-    return attn_output, attn_weights
+        attn_weights = torch.matmul(query.to(key.dtype),
+                                    key.transpose(2, 3)) / math.sqrt(head_dim)
+
+        attn_weights_size = (bsz, num_heads, q_len, kv_seq_len)
+        if attn_weights.size() != attn_weights_size:
+            invalidInputError(False,
+                              f"Attention weights should be of size {attn_weights_size}, "
+                              f"but is {attn_weights.size()}")
+
+        if attention_mask is not None:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights = attn_weights + attention_mask
+
+        if kv_seq_len >= 2048:
+            # for memory considerations, do not upcast attention to fp32 for long sequences
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        else:
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(value.dtype)
+        attn_output = torch.matmul(attn_weights, value)
+        return attn_output, attn_weights
+
+
+def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
+                                bsz, q_len, kv_seq_len, head_dim):
+    query_split = torch.split(query.to(key.dtype), 16, dim=1)
+    key_split = torch.split(key.transpose(2, 3), 16, dim=1)
+    value_split = torch.split(value, 16, dim=1)
+    attn_outputs = []
+    for q, k, v in zip(query_split, key_split, value_split):
+        attn_weights_split = torch.matmul(q, k) / math.sqrt(head_dim)
+        attn_weights_split_size = (bsz, 16, q_len, kv_seq_len)
+        if attn_weights_split.size() != attn_weights_split_size:
+            invalidInputError(False,
+                              f"Splitted attention weights should be of size "
+                              f"{attn_weights_split_size}, but is {attn_weights_split.size()}")
+
+        if attention_mask is not None:
+            attn_mask_size = (bsz, 1, q_len, kv_seq_len)
+            if attention_mask.size() != attn_mask_size:
+                invalidInputError(False,
+                                  f"Attention mask should be of size {attn_mask_size}, "
+                                  f"but is {attention_mask.size()}")
+            attn_weights_split = attn_weights_split + attention_mask
+        attn_weights_split = nn.functional.softmax(attn_weights_split, dim=-1)
+        attn_weights_split = torch.matmul(attn_weights_split, v)
+        attn_outputs.append(attn_weights_split)
+    attn_output = torch.cat(attn_outputs, dim=1)
+    return attn_output, None
 
 
 def llama_model_selective_batching_forward_4_31(
@@ -1601,7 +1643,7 @@ def llama_attention_fast_forward(
         attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
                                                attention_mask,
                                                bsz, q_len, kv_seq_len,
-                                               self.head_dim, self.num_heads)
+                                               self.head_dim, self.num_heads, output_attentions)
 
         attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
         if attn_output.size() != attn_output_size:
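
Note on the technique (not part of the patch): when it triggers, native_sdp_split_qkv_tensor computes attention 16 heads at a time instead of materializing one (bsz, num_heads, q_len, kv_seq_len) fp16 score tensor, which per the patch comment works around a memory block limitation for long (>= 6800 token) fp16 inputs. Below is a minimal standalone sketch of that head-wise splitting in plain PyTorch; the function name sdp_split_heads, the split_size default, and the toy shapes are illustrative assumptions, and it omits the patch's size checks, fp16-only gating, and ipex_llm helpers.

# Standalone sketch of head-wise split scaled-dot-product attention
# (assumption: plain PyTorch, float32 so it runs on any CPU build).
import math
import torch
import torch.nn.functional as F

def sdp_split_heads(query, key, value, attention_mask=None, split_size=16):
    # query/key/value: (bsz, num_heads, seq_len, head_dim)
    head_dim = query.size(-1)
    key_t = key.transpose(2, 3)          # (bsz, num_heads, head_dim, kv_len)
    outputs = []
    # Process split_size heads at a time so only a small slice of the
    # (q_len x kv_len) score matrix is materialized per step.
    for q, k, v in zip(torch.split(query, split_size, dim=1),
                       torch.split(key_t, split_size, dim=1),
                       torch.split(value, split_size, dim=1)):
        scores = torch.matmul(q, k) / math.sqrt(head_dim)
        if attention_mask is not None:   # mask shape: (bsz, 1, q_len, kv_len)
            scores = scores + attention_mask
        outputs.append(torch.matmul(F.softmax(scores, dim=-1), v))
    # Re-assemble the per-chunk results along the head dimension.
    return torch.cat(outputs, dim=1)

if __name__ == "__main__":
    bsz, num_heads, seq_len, head_dim = 1, 32, 512, 128   # toy sizes (assumption)
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = sdp_split_heads(q, k, v)
    ref = torch.matmul(F.softmax(torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim),
                                 dim=-1), v)
    print(out.shape, torch.allclose(out, ref, atol=1e-5))  # torch.Size([1, 32, 512, 128]) True

Because the split is along the head dimension, each chunk's softmax and value matmul are independent of the others, so the concatenated result matches the unsplit computation; only peak memory changes.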