From 763413b7e15e524caf6addb5fe62a35364a46a6b Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Tue, 23 Apr 2024 16:13:25 +0800
Subject: [PATCH] LLM: support llama split tensor for long context in transformers>=4.36. (#10844)

* LLm: support llama split tensor for long context in transformers>=4.36.

* fix dtype.

* fix style.

* fix style.

* fix style.

* fix style.

* fix dtype.

* fix style.
---
 .../ipex_llm/transformers/models/chatglm2.py  |  6 +-
 .../src/ipex_llm/transformers/models/llama.py | 56 ++++++++++---------
 2 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index e6d4ae01..787ed9ce 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -258,8 +258,8 @@ def chatglm2_quantized_attention_forward_8eb45c(
         query_split = torch.split(query_layer, block_size, dim=1)
         key_split = torch.split(key, block_size, dim=1)
         value_split = torch.split(value, block_size, dim=1)
-        context_layer = torch.empty(batch_size, n_head,
-                                    seq_len, head_dim).to(query_layer.device)
+        context_layer = torch.empty(batch_size, n_head, seq_len,
+                                    head_dim, dtype=key.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             if attention_mask is None:
@@ -543,7 +543,7 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         value_split = torch.split(value_layer, block_size, dim=1)
         batch_size, n_head, seq_len, head_dim = query_layer.shape
         context_layer = torch.empty(batch_size, n_head, seq_len,
-                                    head_dim).to(query_layer.device).to(key_layer.dtype)
+                                    head_dim, dtype=key_layer.dtype).to(query_layer.device)
         idx = 0
         for q, k, v in zip(query_split, key_split, value_split):
             result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index e542c6c5..41d11815 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -1028,35 +1028,41 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states,
-                                    repeated_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if should_split_qkv_tensor(query_states, output_attentions):
+            attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
+                                                         repeated_value_states, attention_mask,
+                                                         bsz, q_len, kv_seq_len, self.head_dim,
+                                                         self.num_heads)
+        else:
+            attn_weights = torch.matmul(query_states, repeated_key_states
+                                        .transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            invalidInputError(
-                False,
-                f"Attention weights should be of size "
-                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                 invalidInputError(
                     False,
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
-                    f" but is {attention_mask.size()}"
+                    f"Attention weights should be of size "
+                    f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
                 )
-            attn_weights = attn_weights + attention_mask
 
-        if kv_seq_len >= 2048 or bsz >= 64:
-            # for memory considerations, do not upcast attention to fp32
-            # for long sequences or large batches
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-        else:
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                                 dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, repeated_value_states)
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    invalidInputError(
+                        False,
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)},"
+                        f" but is {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+
+            if kv_seq_len >= 2048 or bsz >= 64:
+                # for memory considerations, do not upcast attention to fp32
+                # for long sequences or large batches
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            else:
+                # upcast attention to fp32
+                attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                     dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, repeated_value_states)
     if use_cache:
         cache_kwargs = None
         key_states, value_states = past_key_value.update(key_states, value_states,
@@ -1438,7 +1444,7 @@ def native_sdp_split_qkv_tensor(query, key, value, attention_mask,
         attn_weights_split = torch.matmul(attn_weights_split, v)
         attn_output[:, idx:idx+block_actual_size, :, :] = attn_weights_split
         idx = idx + block_actual_size
-    return attn_output, None
+    return attn_output.to(key.dtype), None
 
 
 def llama_model_selective_batching_forward_4_31(
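
Note (not part of the patch): the hunks above route long-context inputs through a block-wise attention path, splitting query/key/value along the head dimension so the full attention-weight tensor is never materialized at once, and pre-allocating the output buffer in the key dtype (the dtype fix in this PR). The following is a minimal, self-contained sketch of that pattern, not code from ipex-llm; the function name, block size, and tensor shapes are illustrative assumptions.

# Minimal sketch of block-wise scaled-dot-product attention over head blocks.
# sdp_split_by_heads, block_size and the shapes below are assumptions for
# demonstration only.
import torch
import torch.nn.functional as F


def sdp_split_by_heads(query, key, value, block_size=8):
    # query/key/value: [batch, n_head, seq_len, head_dim]
    batch_size, n_head, seq_len, head_dim = query.shape
    # Pre-allocate the result in the key dtype, as the patch does, so the
    # full output never needs an extra cast afterwards.
    context_layer = torch.empty(batch_size, n_head, seq_len, head_dim,
                                dtype=key.dtype, device=query.device)
    idx = 0
    for q, k, v in zip(torch.split(query, block_size, dim=1),
                       torch.split(key, block_size, dim=1),
                       torch.split(value, block_size, dim=1)):
        # Attention weights exist only for `block_size` heads at a time,
        # which bounds peak memory for long sequences.
        result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        context_layer[:, idx:idx + q.size(1), :, :] = result.to(key.dtype)
        idx += q.size(1)
    return context_layer


if __name__ == "__main__":
    q = torch.randn(1, 32, 4096, 128)
    k = torch.randn(1, 32, 4096, 128)
    v = torch.randn(1, 32, 4096, 128)
    print(sdp_split_by_heads(q, k, v).shape)  # torch.Size([1, 32, 4096, 128])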