From 45c730ff39e96481f9e162929cc6d0e30422f3bd Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Thu, 1 Aug 2024 13:20:20 +0300
Subject: [PATCH] Chatglm support compresskv (#11690)

* chatglm4 support compresskv

* fix

* fix style

* support chatglm2

* fix quantkv conflict

* fix style
---
 .../ipex_llm/transformers/models/chatglm2.py | 63 +++++++++---
 .../ipex_llm/transformers/models/chatglm4.py | 96 +++++++++++++------
 2 files changed, 117 insertions(+), 42 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index 22a4adb9..006ff331 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -25,6 +25,9 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
+    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicCompressCache


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -83,6 +86,14 @@ def chatglm2_model_forward(
         input_ids = torch.empty((batch_size, seq_length),
                                 dtype=inputs_embeds.dtype, device=inputs_embeds.device)

+    if use_cache:
+        use_compress_kv = should_use_compresskv(input_ids)
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+                                                input_ids)
+        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
+                                                                      DynamicCompressCache):
+            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or (
                 past_key_values and seq_length != 1):
@@ -157,7 +168,10 @@ def chatglm2_encoder_forward(
     use_cache: Optional[bool] = True,
     output_hidden_states: Optional[bool] = False,
 ):
-    if not kv_caches:
+    # [CompressKV]
+    use_compress_kv = isinstance(kv_caches, DynamicCompressCache)
+
+    if not kv_caches and not use_compress_kv:
         kv_caches = [None for _ in range(self.num_layers)]
     presents = () if use_cache else None
     if self.gradient_checkpointing and self.training:
@@ -184,12 +198,15 @@
                 hidden_states,
                 attention_mask,
                 rotary_pos_emb,
-                kv_cache=kv_caches[index],
+                kv_cache=kv_caches if use_compress_kv else kv_caches[index],
                 use_cache=use_cache
             )
         hidden_states, kv_cache = layer_ret
         if use_cache:
-            presents = presents + (kv_cache,)
+            if use_compress_kv:
+                presents = kv_caches
+            else:
+                presents = presents + (kv_cache,)

         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
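The model_forward hunk above wraps a legacy tuple-format cache in DynamicCompressCache exactly once per forward, and only when compress-kv is wanted and quantize-kv is not. Below is a minimal sketch of that wrapping idiom, runnable on its own; ToyCompressCache is a stand-in for illustration, not the class from ipex_llm.transformers.kv:

import torch
from typing import List, Optional, Tuple


class ToyCompressCache:
    """Stand-in for DynamicCompressCache; only mirrors the wrapping idiom."""

    def __init__(self):
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple]) -> "ToyCompressCache":
        # Legacy format: a tuple of per-layer (key, value) tensor pairs.
        cache = cls()
        if past_key_values is not None:
            for key, value in past_key_values:
                cache.key_cache.append(key)
                cache.value_cache.append(value)
        return cache

    def get_usable_length(self, new_seq_len: int, layer_idx: int) -> int:
        # On the prefill pass the cache is empty, so the usable length is 0.
        if layer_idx >= len(self.key_cache):
            return 0
        return self.key_cache[layer_idx].shape[2]


# Mirrors the gating in chatglm2_model_forward: wrap only if not already wrapped.
past_key_values = None
if not isinstance(past_key_values, ToyCompressCache):
    past_key_values = ToyCompressCache.from_legacy_cache(past_key_values)
print(past_key_values.get_usable_length(8, 0))  # 0 on the first (prefill) pass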
@@ -207,10 +224,16 @@ def chatglm2_attention_forward(
     # hidden_states: [seq_len, bsz, head_dim]
     q_len, bsz, _ = hidden_states.size()

+    # [CompressKV]
+    use_compresskv = isinstance(kv_cache, DynamicCompressCache)
+
     # kv_cache: [seq_len, bsz, n_kv_head, head_dim] ->
     # past_key_value: [bsz, n_kv_head, seq_len, head_dim]
-    past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
-                                                    kv_cache[1].permute(1, 2, 0, 3))
+    if use_compresskv:
+        past_key_value = kv_cache
+    else:
+        past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
+                                                        kv_cache[1].permute(1, 2, 0, 3))

     n_head = self.num_attention_heads_per_partition
     n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
@@ -227,7 +250,11 @@

     kv_seq_len = key_states.shape[2]
     if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[2]
+        if use_compresskv:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
+                                                           self.layer_number - 1)
+        else:
+            kv_seq_len += past_key_value[0].shape[2]

     # IPEX-LLM OPT: fuse rope
     inv_freq, position_ids = rotary_pos_emb
@@ -249,13 +276,23 @@

     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-    key_states, value_states = update_past_key_value(
-        past_key_value, key_states, value_states,
-        kv_seq_len, use_quantize_kv, hidden_states.device
-    )
-    # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
-    past_key_value = (key_states.permute(2, 0, 1, 3),
-                      value_states.permute(2, 0, 1, 3)) if use_cache else None
+    if use_quantize_kv or (not use_compresskv):
+        key_states, value_states = update_past_key_value(
+            past_key_value, key_states, value_states,
+            kv_seq_len, use_quantize_kv, hidden_states.device
+        )
+        # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
+        past_key_value = (key_states.permute(2, 0, 1, 3),
+                          value_states.permute(2, 0, 1, 3)) if use_cache else None
+    else:
+        from transformers.configuration_utils import PretrainedConfig
+        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_number - 1,
+            query_states, attention_mask, n_head // n_kv_head,
+            self.config, enough_kv_room, 256
+        )

     # IPEX-LLM OPT: sdp
     attn_weights = None
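In the attention hunks above, the compress path adds past length via get_usable_length and appends new keys/values via past_key_value.update, passing the query states, attention mask, the GQA group size n_head // n_kv_head, a config, an is-there-room flag, and a window size of 256. A hedged shape walk-through with a no-op stand-in follows; the real DynamicCompressCache evicts tokens, while this one only concatenates so the call pattern and shapes can be checked:

import torch


class NoOpCompressCache:
    """Concatenating stand-in; real compression/eviction is omitted."""

    def __init__(self, num_layers: int):
        self.kv = [None] * num_layers

    def get_usable_length(self, new_seq_len: int, layer_idx: int) -> int:
        return 0 if self.kv[layer_idx] is None else self.kv[layer_idx][0].shape[2]

    def update(self, key, value, layer_idx, *compress_args):
        # compress_args mimics (query_states, attention_mask, n_rep,
        # config, enough_kv_room, window_size) from the patch; unused here.
        if self.kv[layer_idx] is not None:
            key = torch.cat([self.kv[layer_idx][0], key], dim=2)
            value = torch.cat([self.kv[layer_idx][1], value], dim=2)
        self.kv[layer_idx] = (key, value)
        return key, value


bsz, n_kv_head, head_dim = 1, 2, 8
cache = NoOpCompressCache(num_layers=1)
k = torch.randn(bsz, n_kv_head, 4, head_dim)      # prefill: 4 prompt tokens
v = torch.randn_like(k)
kv_seq_len = 4 + cache.get_usable_length(4, 0)    # 4 + 0
k_all, v_all = cache.update(k, v, 0)
k1 = torch.randn(bsz, n_kv_head, 1, head_dim)     # decode: 1 new token
kv_seq_len = 1 + cache.get_usable_length(1, 0)    # 1 + 4 == 5
k_all, v_all = cache.update(k1, torch.randn_like(k1), 0)
print(kv_seq_len, k_all.shape)                    # 5 torch.Size([1, 2, 5, 8])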
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index f567d17e..361d405c 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -20,9 +20,11 @@ import torch
 from typing import Optional, Tuple, Union

 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
-from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
+    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
+from ipex_llm.transformers.kv import DynamicCompressCache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 import math
@@ -46,6 +48,15 @@ def chatglm4_model_forward(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

+    if use_cache:
+        inputs = input_ids if input_ids is not None else inputs_embeds
+        use_compress_kv = should_use_compresskv(inputs)
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
+                                                inputs)
+        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
+                                                                      DynamicCompressCache):
+            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+
     if inputs_embeds is None:
         batch_size, seq_length = input_ids.shape
         inputs_embeds = self.embedding(input_ids)
@@ -134,9 +145,15 @@ def chatglm4_attention_forward(
     # hidden_states: [b, sq, h]
     bsz, q_len, _ = hidden_states.size()

+    # [CompressKV]
+    use_compresskv = isinstance(kv_cache, DynamicCompressCache)
+
     # past_key_value: [bsz, n_kv_head, seq_len, head_dim]
-    past_key_value = None if kv_cache is None else (kv_cache[0],
-                                                    kv_cache[1])
+    if use_compresskv:
+        past_key_value = kv_cache
+    else:
+        past_key_value = None if kv_cache is None else (kv_cache[0],
+                                                        kv_cache[1])

     n_head = self.num_attention_heads_per_partition
     n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
@@ -153,7 +170,11 @@

     kv_seq_len = key_states.shape[2]
     if past_key_value is not None:
-        kv_seq_len += past_key_value[0].shape[2]
+        if use_compresskv:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
+                                                           self.layer_number - 1)
+        else:
+            kv_seq_len += past_key_value[0].shape[2]

     # IPEX-LLM OPT: fuse rope
     inv_freq, position_ids = rotary_pos_emb
@@ -175,19 +196,29 @@

     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-    key_states, value_states = update_past_key_value(
-        past_key_value, key_states, value_states,
-        kv_seq_len, use_quantize_kv, hidden_states.device
-    )
-    if use_cache:
-        if past_key_value is None:
-            past_key_value = torch.cat((key_states.unsqueeze(0).unsqueeze(0),
-                                        value_states.unsqueeze(0).unsqueeze(0)), dim=1)
+    if use_quantize_kv or (not use_compresskv):
+        key_states, value_states = update_past_key_value(
+            past_key_value, key_states, value_states,
+            kv_seq_len, use_quantize_kv, hidden_states.device
+        )
+        if use_cache:
+            if past_key_value is None:
+                past_key_value = torch.cat((key_states.unsqueeze(0).unsqueeze(0),
+                                            value_states.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                past_key_value = (key_states, value_states)
         else:
-            past_key_value = (key_states, value_states)
+            past_key_value = None
     else:
-        past_key_value = None
+        from transformers.configuration_utils import PretrainedConfig
+        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_number - 1,
+            query_states, attention_mask, n_head // n_kv_head,
+            self.config, enough_kv_room, 256
+        )

     # IPEX-LLM OPT: sdp
     attn_weights = None
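With CompressKV, each attention layer receives the whole cache object and indexes it with self.layer_number - 1; the encoder hunks that follow simply thread that single object through every layer and return it as presents. A minimal mirror of that threading, under illustrative names rather than library code:

def run_layers(layers, hidden, kv_caches, use_compress_kv, use_cache=True):
    # With a shared cache object, every layer sees the same `kv_caches`;
    # otherwise each layer gets its own per-layer entry, tuple-style.
    presents = () if use_cache else None
    for index, layer in enumerate(layers):
        hidden, kv = layer(hidden,
                           kv_cache=kv_caches if use_compress_kv else kv_caches[index])
        if use_cache:
            presents = kv_caches if use_compress_kv else presents + (kv,)
    return hidden, presents


# Tuple-style: two layers, each returning its own kv entry.
layers = [lambda h, kv_cache: (h + 1, kv_cache)] * 2
out, presents = run_layers(layers, 0, ["kv0", "kv1"], use_compress_kv=False)
print(out, presents)  # 2 ('kv0', 'kv1')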
@@ -244,7 +275,10 @@ def chatglm4_encoder_forward(
     use_cache: Optional[bool] = True,
     output_hidden_states: Optional[bool] = False,
 ):
-    if not kv_caches:
+    # [CompressKV]
+    use_compress_kv = isinstance(kv_caches, DynamicCompressCache)
+
+    if not kv_caches and not use_compress_kv:
         kv_caches = [None for _ in range(self.num_layers)]
     presents = () if use_cache else None
     if self.gradient_checkpointing and self.training:
@@ -274,26 +308,30 @@
                 hidden_states,
                 attention_mask,
                 rotary_pos_emb,
-                kv_cache=kv_caches[index],
+                kv_cache=kv_caches if use_compress_kv else kv_caches[index],
                 use_cache=use_cache
             )
         hidden_states, kv_cache = layer_ret
         if use_cache:
-            # token by token decoding, use tuple format
-            if kv_caches[0] is not None:
-                presents = presents + (kv_cache,)
-            # prefilling in decoding, use tensor format to save cuda memory
+            if use_compress_kv:
+                presents = kv_caches
             else:
-                if len(presents) == 0:
-                    presents = kv_cache
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
                 else:
-                    # bigdl-llm change starts
-                    # to fix first token's kv cache error of tensor format in pipeline parallel
-                    if isinstance(kv_cache, tuple):
-                        kv_cache = torch.tensor(kv_cache,
-                                                dtype=hidden_states.dtype).to(hidden_states.device)
-                    # bigdl-llm change ends
-                    presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        # bigdl-llm change starts
+                        # to fix first token's kv cache error of tensor format in pipeline parallel
+                        if isinstance(kv_cache, tuple):
+                            kv_cache = torch.tensor(
+                                kv_cache,
+                                dtype=hidden_states.dtype).to(hidden_states.device)
+                        # bigdl-llm change ends
+                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)

         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
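For completeness, a hedged end-to-end sketch of how a user would exercise the new path. The IPEX_LLM_COMPRESS_KV toggle and the long-input condition reflect a reading of should_use_compresskv at the time of this patch; verify both against your ipex-llm version, and note that enabling quantize-kv takes precedence over CompressKV in the gating above.

import os
os.environ["IPEX_LLM_COMPRESS_KV"] = "1"  # assumed env toggle read by should_use_compresskv

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel

model_path = "THUDM/chatglm3-6b"  # any checkpoint routed through these chatglm2 forwards
model = AutoModel.from_pretrained(model_path, load_in_4bit=True,
                                  trust_remote_code=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

prompt = "Summarize the following text: " + "lorem ipsum " * 1000
inputs = tokenizer(prompt, return_tensors="pt").to("xpu")
with torch.inference_mode():
    # On a sufficiently long prompt, chatglm2_model_forward wraps past_key_values
    # in DynamicCompressCache, and decode steps run through the compress branch.
    output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))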