diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 9032bbf5..1367e3a3 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -96,6 +96,19 @@ def repeat_kv(key: torch.Tensor, value: torch.Tensor, n_head: int) -> (torch.Ten
     return key, value
 
 
+def should_split_qkv_tensor(query_layer, bsz, n_head, seq_len):
+    if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+        return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+    elif query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+        # split tensor for memory block limitation
+        # support fp16 and set input length threshold at 5000 for now
+        return True
+    elif query_layer.element_size()*bsz*n_head*seq_len*seq_len >= 4*1024**3:
+        # attn_weight size larger than memory block limitation 4GB
+        return True
+    return False
+
+
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
@@ -250,9 +263,7 @@ def chatglm2_quantized_attention_forward_8eb45c(
         else:
             key, value = key_layer, value_layer
 
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 5000 for now
-        if query_layer.dtype == torch.float16 and query_layer.shape[2] >= 5000:
+        if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
             # split second dim to block size = 8
             block_size = 8
             query_split = torch.split(query_layer, block_size, dim=1)
@@ -529,9 +540,8 @@ def core_attn_forward_8eb45c(query_layer, key_layer, value_layer, attention_mask
         query_layer = query_layer.permute(1, 2, 0, 3)
         L, S = query_layer.shape[2], key_layer.shape[2]
         if attention_mask is None and L == S:
-            # split tensor for memory block limitation
-            # support fp16 and set input length threshold at 5000 for now
-            if query_layer.dtype == torch.float16 and L >= 5000:
+            batch_size, n_head, seq_len, head_dim = query_layer.shape
+            if should_split_qkv_tensor(query_layer, batch_size, n_head, seq_len):
                 # split second dim to block size = 8
                 block_size = 8
                 query_split = torch.split(query_layer.to(key_layer.dtype), block_size, dim=1)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 703d3163..e3e06fe9 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -215,12 +215,18 @@ def should_use_fast_rope(self, query_states, position_ids):
     return use_fuse_rope
 
 
-def should_split_qkv_tensor(query_states, output_attentions):
-    if not output_attentions and query_states.dtype == torch.float16 and \
-            query_states.shape[2] >= 6800:
-        # split tensor for memory block limitation
-        # support fp16 and set input length threshold at 6800 for now
-        return True
+def should_split_qkv_tensor(query_states, bsz, num_heads, q_len, kv_seq_len, output_attentions):
+    if not output_attentions:
+        if os.environ.get("IPEX_LLM_SPLIT_QKV", None) is not None:
+            return os.environ.get("IPEX_LLM_SPLIT_QKV", None) == "1"
+        elif query_states.dtype == torch.float16 and \
+                query_states.shape[2] >= 6800:
+            # split tensor for memory block limitation
+            # support fp16 and set input length threshold at 6800 for now
+            return True
+        elif query_states.element_size()*bsz*num_heads*q_len*kv_seq_len >= 4*1024**3:
+            # attn_weight size larger than memory block limitation 4GB
+            return True
     return False
 
 
@@ -1029,7 +1035,8 @@ def llama_attention_forward_4_36_quantized(
     if len(past_key_value.key_cache) <= self.layer_idx:
         repeated_key_states = repeat_kv(key_states, self.num_key_value_groups)
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
-        if should_split_qkv_tensor(query_states, output_attentions):
+        if should_split_qkv_tensor(query_states, bsz, self.num_heads,
+                                   q_len, kv_seq_len, output_attentions):
             attn_output, _ = native_sdp_split_qkv_tensor(query_states, repeated_key_states,
                                                          repeated_value_states, attention_mask,
                                                          bsz, q_len, kv_seq_len, self.head_dim,
@@ -1403,7 +1410,7 @@ def llama_attention_forward_4_36_original(
 
 def native_sdp(query, key, value, attention_mask,
                bsz, q_len, kv_seq_len, head_dim, num_heads, output_attentions):
-    if should_split_qkv_tensor(query, output_attentions):
+    if should_split_qkv_tensor(query, bsz, num_heads, q_len, kv_seq_len, output_attentions):
         return native_sdp_split_qkv_tensor(query, key, value, attention_mask,
                                            bsz, q_len, kv_seq_len, head_dim, num_heads)
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index e2c84488..2db6bd1b 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -74,9 +74,7 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
 
 
 def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
-    if os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
-        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
-    elif os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
+    if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
             "Please use `IPEX_LLM_QUANTIZE_KV_CACHE` instead."
@@ -84,6 +82,8 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor) -> bool:
         return os.environ["BIGDL_QUANTIZE_KV_CACHE"] == "1"
     elif os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE", None) is not None:
         return os.environ["IPEX_LLM_QUANTIZE_KV_CACHE"] == "1"
+    elif os.environ.get("IPEX_LLM_LOW_MEM", None) is not None:
+        return os.environ["IPEX_LLM_LOW_MEM"] == "1"
     else:
         return x.device.type == 'xpu' and kv_cache_device_check(x) \
             and hasattr(linear, "qtype") and \
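
Note: the snippet below is a minimal, self-contained sketch (not the patched library code) of the split heuristic this diff introduces, assuming a (bsz, n_head, q_len, head_dim) query layout; the helper name _should_split_qkv and the fp16_threshold parameter are illustrative (the llama.py path uses 6800, the chatglm2.py path uses 5000).

import os
import torch


def _should_split_qkv(query: torch.Tensor, bsz: int, n_head: int,
                      q_len: int, kv_seq_len: int, fp16_threshold: int = 6800) -> bool:
    env = os.environ.get("IPEX_LLM_SPLIT_QKV")
    if env is not None:
        # explicit user override: "1" forces splitting, anything else disables it
        return env == "1"
    if query.dtype == torch.float16 and query.shape[2] >= fp16_threshold:
        # fp16 inputs beyond the length threshold are split pre-emptively
        return True
    # estimated attn_weights size = element_size * bsz * n_head * q_len * kv_seq_len
    attn_bytes = query.element_size() * bsz * n_head * q_len * kv_seq_len
    return attn_bytes >= 4 * 1024 ** 3  # larger than the 4 GB memory block limit


if __name__ == "__main__":
    q = torch.empty(1, 32, 8192, 128, dtype=torch.float16)
    # fp16 attn_weights would be 2 * 1 * 32 * 8192 * 8192 bytes = 4 GiB, so this prints True
    print(_should_split_qkv(q, bsz=1, n_head=32, q_len=8192, kv_seq_len=8192))

With this ordering, IPEX_LLM_SPLIT_QKV=1 forces splitting and IPEX_LLM_SPLIT_QKV=0 disables it regardless of dtype or sequence length; when the variable is unset, the fp16 length threshold and the 4 GB attn_weights estimate decide, which matches the behavior added to both chatglm2.py and llama.py above.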