From 4b024b7aac4a2890af4a6753e130973ae9769e55 Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Wed, 10 Apr 2024 16:59:06 +0800
Subject: [PATCH] LLM: optimize chatglm2 8k input. (#10723)

* LLM: optimize chatglm2 8k input.

* rename.
---
 .../ipex_llm/transformers/models/chatglm2.py | 47 ++++++++++++++-----
 1 file changed, 36 insertions(+), 11 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 196473ae..9812926f 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -252,10 +252,31 @@ def chatglm2_quantized_attention_forward_8eb45c(
         else:
             key, value = key_layer, value_layer
 
-        if attention_mask is None:
-            context_layer = F.scaled_dot_product_attention(query_layer, key, value, is_causal=True)
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer, block_size, dim=1)
+            key_split = torch.split(key, block_size, dim=1)
+            value_split = torch.split(value, block_size, dim=1)
+            context_layer = torch.empty(batch_size, n_head,
+                                        seq_len, head_dim).to(query_layer.device)
+            idx = 0
+            for q, k, v in zip(query_split, key_split, value_split):
+                if attention_mask is None:
+                    result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+                else:
+                    result = F.scaled_dot_product_attention(q, k, v, attention_mask)
+                context_layer[:, idx:idx+q.shape[1], :, :] = result
+                idx = idx + q.shape[1]
         else:
-            context_layer = F.scaled_dot_product_attention(query_layer, key, value, attention_mask)
+            if attention_mask is None:
+                context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                               value, is_causal=True)
+            else:
+                context_layer = F.scaled_dot_product_attention(query_layer, key,
+                                                               value, attention_mask)
         context_layer = context_layer.to(query_layer.dtype)
 
         if use_cache:
@@ -517,15 +538,19 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         # split tensor for memory block limitation
         # support fp16 and set input length threshold at 5000 for now
         if query_layer.dtype == torch.float16 and L >= 5000:
-            # split first dim 32 -> 8
-            query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
-            key_sp = torch.split(key_layer, 8, dim=1)
-            value_sp = torch.split(value_layer, 8, dim=1)
-            results = []
-            for q, k, v in zip(query_sp, key_sp, value_sp):
+            # split second dim to block size = 8
+            block_size = 8
+            query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
+            key_split = torch.split(key_layer, block_size, dim=1)
+            value_split = torch.split(value_layer, block_size, dim=1)
+            batch_size, n_head, seq_len, head_dim = query_layer.shape
+            context_layer = torch.empty(batch_size, n_head, seq_len,
+                                        head_dim).to(query_layer.device).to(key_layer.dtype)
+            idx = 0
+            for q, k, v in zip(query_split, key_split, value_split):
                 result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
-                results.append(result)
-            context_layer = torch.cat(results, dim=1)
+                context_layer[:, idx:idx+q.shape[1], :, :] = result
+                idx = idx + q.shape[1]
         else:
             context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
                                                            key_layer,
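
For reference, below is a minimal standalone sketch (not part of the patch) of the block-split attention idea applied above: q/k/v are split along the head dimension (dim=1) into blocks of 8 heads, F.scaled_dot_product_attention runs per block, and each block's result is written into a pre-allocated output tensor, bounding the size of each intermediate attention buffer for long prompts. The helper name blocked_sdpa and the toy shapes in the usage section are illustrative assumptions, not code from the repository, and the sketch assumes query, key, and value share the same head count.

# blocked_sdpa: illustrative sketch; expects tensors shaped [batch, n_head, seq_len, head_dim]
import torch
import torch.nn.functional as F


def blocked_sdpa(query, key, value, attention_mask=None, block_size=8):
    # Pre-allocate the full output so each head block can be written in place.
    batch_size, n_head, seq_len, head_dim = query.shape
    context = torch.empty(batch_size, n_head, seq_len, head_dim,
                          dtype=query.dtype, device=query.device)
    idx = 0
    # Split q/k/v into groups of `block_size` heads and attend per group.
    for q, k, v in zip(torch.split(query, block_size, dim=1),
                       torch.split(key, block_size, dim=1),
                       torch.split(value, block_size, dim=1)):
        if attention_mask is None:
            result = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            result = F.scaled_dot_product_attention(q, k, v, attention_mask)
        context[:, idx:idx + q.shape[1], :, :] = result
        idx += q.shape[1]
    return context


if __name__ == "__main__":
    # Toy shapes chosen only to keep the example small.
    q = torch.randn(1, 8, 1024, 64)
    k = torch.randn(1, 8, 1024, 64)
    v = torch.randn(1, 8, 1024, 64)
    print(blocked_sdpa(q, k, v).shape)  # torch.Size([1, 8, 1024, 64])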