From 1a9b8204a44ecf2b9f572f7774add946fe6b247e Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Sun, 7 Apr 2024 09:39:21 +0800
Subject: [PATCH] LLM: support int4 fp16 chatglm2-6b 8k input. (#10648)

---
 .../ipex_llm/transformers/models/chatglm2.py | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 1c0c670a..3d69cd18 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -512,10 +512,23 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
     query_layer = query_layer.permute(1, 2, 0, 3)
     L, S = query_layer.shape[2], key_layer.shape[2]
     if attention_mask is None and L == S:
-        context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
-                                                       key_layer,
-                                                       value_layer,
-                                                       is_causal=True).to(key_layer.dtype)
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        if query_layer.dtype == torch.float16 and L >= 5000:
+            # split first dim 32 -> 8
+            query_sp = torch.split(query_layer.to(key_layer.dtype), 8, dim=1)
+            key_sp = torch.split(key_layer, 8, dim=1)
+            value_sp = torch.split(value_layer, 8, dim=1)
+            results = []
+            for q, k, v in zip(query_sp, key_sp, value_sp):
+                result = F.scaled_dot_product_attention(q, k, v, is_causal=True).to(k.dtype)
+                results.append(result)
+            context_layer = torch.cat(results, dim=1)
+        else:
+            context_layer = F.scaled_dot_product_attention(query_layer.to(key_layer.dtype),
+                                                           key_layer,
+                                                           value_layer,
+                                                           is_causal=True).to(key_layer.dtype)
     else:
         if use_esimd_sdp(query_layer.shape[2], key_layer.shape[2],
                          query_layer.shape[-1], query_layer):
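
Per the patch comments, the change works around a memory block limitation for fp16 scaled_dot_product_attention on long prompts by splitting the query/key/value tensors along the head dimension (chatglm2-6b has 32 heads, split into chunks of 8) and concatenating the partial outputs. Below is a minimal standalone sketch of the same idea, not taken from the patch: it assumes 4-D [batch, heads, seq_len, head_dim] tensors, and the helper name chunked_causal_sdpa as well as the default chunk size and length threshold (mirroring the patch's 8 and 5000) are illustrative, not part of the ipex_llm API.

    # Sketch: head-chunked causal SDPA for long fp16 inputs (assumed helper,
    # not the library's implementation).
    import torch
    import torch.nn.functional as F

    def chunked_causal_sdpa(query, key, value, head_chunk=8, length_threshold=5000):
        # Short sequences or non-fp16 inputs: a single fused SDPA call is fine.
        if query.dtype != torch.float16 or query.shape[2] < length_threshold:
            return F.scaled_dot_product_attention(query.to(key.dtype), key, value,
                                                  is_causal=True).to(key.dtype)
        # Split along the head dimension (dim=1) so each SDPA call allocates a
        # smaller attention buffer, then concatenate the per-chunk results.
        outputs = []
        for q, k, v in zip(torch.split(query.to(key.dtype), head_chunk, dim=1),
                           torch.split(key, head_chunk, dim=1),
                           torch.split(value, head_chunk, dim=1)):
            outputs.append(F.scaled_dot_product_attention(q, k, v,
                                                          is_causal=True).to(k.dtype))
        return torch.cat(outputs, dim=1)

    if __name__ == "__main__":
        # Smoke test on CPU with fp32 (fp16 SDPA support varies by device);
        # the short sequence here takes the unsplit fallback path.
        q = torch.randn(1, 32, 128, 64)
        k = torch.randn(1, 32, 128, 64)
        v = torch.randn(1, 32, 128, 64)
        print(chunked_causal_sdpa(q, k, v).shape)  # torch.Size([1, 32, 128, 64])

Splitting on the head dimension keeps the causal masking of each chunk identical to the unsplit call, since attention never mixes information across heads, so the concatenated result matches the single-call output up to floating-point ordering effects.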