Chatglm support compresskv (#11690)
* chatglm4 support compresskv * fix * fix style * support chatglm2 * fix quantkv conflict * fix style
This commit is contained in:
		
							parent
							
								
									762ad49362
								
							
						
					
					
						commit
						45c730ff39
					
				
					 2 changed files with 117 additions and 42 deletions
				
			
		| 
						 | 
				
			
			@ -25,6 +25,9 @@ from ipex_llm.utils.common.log4Error import invalidInputError
 | 
			
		|||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
 | 
			
		||||
    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			@ -83,6 +86,14 @@ def chatglm2_model_forward(
 | 
			
		|||
        input_ids = torch.empty((batch_size, seq_length),
 | 
			
		||||
                                dtype=inputs_embeds.dtype, device=inputs_embeds.device)
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        use_compress_kv = should_use_compresskv(input_ids)
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
 | 
			
		||||
                                                input_ids)
 | 
			
		||||
        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
 | 
			
		||||
                                                                      DynamicCompressCache):
 | 
			
		||||
            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
    if full_attention_mask is None:
 | 
			
		||||
        if (attention_mask is not None and not attention_mask.all()) or (
 | 
			
		||||
                past_key_values and seq_length != 1):
 | 
			
		||||
| 
						 | 
				
			
			@ -157,7 +168,10 @@ def chatglm2_encoder_forward(
 | 
			
		|||
    use_cache: Optional[bool] = True,
 | 
			
		||||
    output_hidden_states: Optional[bool] = False,
 | 
			
		||||
):
 | 
			
		||||
    if not kv_caches:
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compress_kv = isinstance(kv_caches, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    if not kv_caches and not use_compress_kv:
 | 
			
		||||
        kv_caches = [None for _ in range(self.num_layers)]
 | 
			
		||||
    presents = () if use_cache else None
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
| 
						 | 
				
			
			@ -184,12 +198,15 @@ def chatglm2_encoder_forward(
 | 
			
		|||
                hidden_states,
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                rotary_pos_emb,
 | 
			
		||||
                kv_cache=kv_caches[index],
 | 
			
		||||
                kv_cache=kv_caches if use_compress_kv else kv_caches[index],
 | 
			
		||||
                use_cache=use_cache
 | 
			
		||||
            )
 | 
			
		||||
        hidden_states, kv_cache = layer_ret
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            presents = presents + (kv_cache,)
 | 
			
		||||
            if use_compress_kv:
 | 
			
		||||
                presents = kv_caches
 | 
			
		||||
            else:
 | 
			
		||||
                presents = presents + (kv_cache,)
 | 
			
		||||
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
| 
						 | 
				
			
			@ -207,10 +224,16 @@ def chatglm2_attention_forward(
 | 
			
		|||
    # hidden_states: [seq_len, bsz, head_dim]
 | 
			
		||||
    q_len, bsz, _ = hidden_states.size()
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = isinstance(kv_cache, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    # kv_cache: [seq_len, bsz, n_kv_head, head_dim] ->
 | 
			
		||||
    # past_key_value: [bsz, n_kv_head, seq_len, head_dim]
 | 
			
		||||
    past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
 | 
			
		||||
                                                    kv_cache[1].permute(1, 2, 0, 3))
 | 
			
		||||
    if use_compresskv:
 | 
			
		||||
        past_key_value = kv_cache
 | 
			
		||||
    else:
 | 
			
		||||
        past_key_value = None if kv_cache is None else (kv_cache[0].permute(1, 2, 0, 3),
 | 
			
		||||
                                                        kv_cache[1].permute(1, 2, 0, 3))
 | 
			
		||||
 | 
			
		||||
    n_head = self.num_attention_heads_per_partition
 | 
			
		||||
    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
 | 
			
		||||
| 
						 | 
				
			
			@ -227,7 +250,11 @@ def chatglm2_attention_forward(
 | 
			
		|||
 | 
			
		||||
    kv_seq_len = key_states.shape[2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
 | 
			
		||||
                                                           self.layer_number - 1)
 | 
			
		||||
        else:
 | 
			
		||||
            kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    inv_freq, position_ids = rotary_pos_emb
 | 
			
		||||
| 
						 | 
				
			
			@ -249,13 +276,23 @@ def chatglm2_attention_forward(
 | 
			
		|||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
    )
 | 
			
		||||
    # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
 | 
			
		||||
    past_key_value = (key_states.permute(2, 0, 1, 3),
 | 
			
		||||
                      value_states.permute(2, 0, 1, 3)) if use_cache else None
 | 
			
		||||
    if use_quantize_kv or (not use_compresskv):
 | 
			
		||||
        key_states, value_states = update_past_key_value(
 | 
			
		||||
            past_key_value, key_states, value_states,
 | 
			
		||||
            kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
        )
 | 
			
		||||
        # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
 | 
			
		||||
        past_key_value = (key_states.permute(2, 0, 1, 3),
 | 
			
		||||
                          value_states.permute(2, 0, 1, 3)) if use_cache else None
 | 
			
		||||
    else:
 | 
			
		||||
        from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
 | 
			
		||||
        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
 | 
			
		||||
        key_states, value_states = past_key_value.update(
 | 
			
		||||
            key_states, value_states, self.layer_number - 1,
 | 
			
		||||
            query_states, attention_mask, n_head // n_kv_head,
 | 
			
		||||
            self.config, enough_kv_room, 256
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,9 +20,11 @@
 | 
			
		|||
import torch
 | 
			
		||||
from typing import Optional, Tuple, Union
 | 
			
		||||
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, use_sdp_causal
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
 | 
			
		||||
    use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.chatglm2 import repeat_kv
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicCompressCache
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
import math
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -46,6 +48,15 @@ def chatglm4_model_forward(
 | 
			
		|||
    use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs)
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
 | 
			
		||||
                                                inputs)
 | 
			
		||||
        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
 | 
			
		||||
                                                                      DynamicCompressCache):
 | 
			
		||||
            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
 | 
			
		||||
 | 
			
		||||
    if inputs_embeds is None:
 | 
			
		||||
        batch_size, seq_length = input_ids.shape
 | 
			
		||||
        inputs_embeds = self.embedding(input_ids)
 | 
			
		||||
| 
						 | 
				
			
			@ -134,9 +145,15 @@ def chatglm4_attention_forward(
 | 
			
		|||
    # hidden_states: [b, sq, h]
 | 
			
		||||
    bsz, q_len, _ = hidden_states.size()
 | 
			
		||||
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compresskv = isinstance(kv_cache, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    # past_key_value: [bsz, n_kv_head, seq_len, head_dim]
 | 
			
		||||
    past_key_value = None if kv_cache is None else (kv_cache[0],
 | 
			
		||||
                                                    kv_cache[1])
 | 
			
		||||
    if use_compresskv:
 | 
			
		||||
        past_key_value = kv_cache
 | 
			
		||||
    else:
 | 
			
		||||
        past_key_value = None if kv_cache is None else (kv_cache[0],
 | 
			
		||||
                                                        kv_cache[1])
 | 
			
		||||
 | 
			
		||||
    n_head = self.num_attention_heads_per_partition
 | 
			
		||||
    n_kv_head = self.num_multi_query_groups_per_partition if self.multi_query_attention else n_head
 | 
			
		||||
| 
						 | 
				
			
			@ -153,7 +170,11 @@ def chatglm4_attention_forward(
 | 
			
		|||
 | 
			
		||||
    kv_seq_len = key_states.shape[2]
 | 
			
		||||
    if past_key_value is not None:
 | 
			
		||||
        kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
        if use_compresskv:
 | 
			
		||||
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
 | 
			
		||||
                                                           self.layer_number - 1)
 | 
			
		||||
        else:
 | 
			
		||||
            kv_seq_len += past_key_value[0].shape[2]
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: fuse rope
 | 
			
		||||
    inv_freq, position_ids = rotary_pos_emb
 | 
			
		||||
| 
						 | 
				
			
			@ -175,19 +196,29 @@ def chatglm4_attention_forward(
 | 
			
		|||
 | 
			
		||||
    # IPEX-LLM OPT: kv cache and quantize kv
 | 
			
		||||
    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
 | 
			
		||||
    key_states, value_states = update_past_key_value(
 | 
			
		||||
        past_key_value, key_states, value_states,
 | 
			
		||||
        kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if past_key_value is None:
 | 
			
		||||
            past_key_value = torch.cat((key_states.unsqueeze(0).unsqueeze(0),
 | 
			
		||||
                                        value_states.unsqueeze(0).unsqueeze(0)), dim=1)
 | 
			
		||||
    if use_quantize_kv or (not use_compresskv):
 | 
			
		||||
        key_states, value_states = update_past_key_value(
 | 
			
		||||
            past_key_value, key_states, value_states,
 | 
			
		||||
            kv_seq_len, use_quantize_kv, hidden_states.device
 | 
			
		||||
        )
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if past_key_value is None:
 | 
			
		||||
                past_key_value = torch.cat((key_states.unsqueeze(0).unsqueeze(0),
 | 
			
		||||
                                            value_states.unsqueeze(0).unsqueeze(0)), dim=1)
 | 
			
		||||
            else:
 | 
			
		||||
                past_key_value = (key_states, value_states)
 | 
			
		||||
        else:
 | 
			
		||||
            past_key_value = (key_states, value_states)
 | 
			
		||||
            past_key_value = None
 | 
			
		||||
    else:
 | 
			
		||||
        past_key_value = None
 | 
			
		||||
        from transformers.configuration_utils import PretrainedConfig
 | 
			
		||||
        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
 | 
			
		||||
        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_number - 1)
 | 
			
		||||
        key_states, value_states = past_key_value.update(
 | 
			
		||||
            key_states, value_states, self.layer_number - 1,
 | 
			
		||||
            query_states, attention_mask, n_head // n_kv_head,
 | 
			
		||||
            self.config, enough_kv_room, 256
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdp
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
| 
						 | 
				
			
			@ -244,7 +275,10 @@ def chatglm4_encoder_forward(
 | 
			
		|||
    use_cache: Optional[bool] = True,
 | 
			
		||||
    output_hidden_states: Optional[bool] = False,
 | 
			
		||||
):
 | 
			
		||||
    if not kv_caches:
 | 
			
		||||
    # [CompressKV]
 | 
			
		||||
    use_compress_kv = isinstance(kv_caches, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    if not kv_caches and not use_compress_kv:
 | 
			
		||||
        kv_caches = [None for _ in range(self.num_layers)]
 | 
			
		||||
    presents = () if use_cache else None
 | 
			
		||||
    if self.gradient_checkpointing and self.training:
 | 
			
		||||
| 
						 | 
				
			
			@ -274,26 +308,30 @@ def chatglm4_encoder_forward(
 | 
			
		|||
                hidden_states,
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                rotary_pos_emb,
 | 
			
		||||
                kv_cache=kv_caches[index],
 | 
			
		||||
                kv_cache=kv_caches if use_compress_kv else kv_caches[index],
 | 
			
		||||
                use_cache=use_cache
 | 
			
		||||
            )
 | 
			
		||||
        hidden_states, kv_cache = layer_ret
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            # token by token decoding, use tuple format
 | 
			
		||||
            if kv_caches[0] is not None:
 | 
			
		||||
                presents = presents + (kv_cache,)
 | 
			
		||||
            # prefilling in decoding, use tensor format to save cuda memory
 | 
			
		||||
            if use_compress_kv:
 | 
			
		||||
                presents = kv_caches
 | 
			
		||||
            else:
 | 
			
		||||
                if len(presents) == 0:
 | 
			
		||||
                    presents = kv_cache
 | 
			
		||||
                # token by token decoding, use tuple format
 | 
			
		||||
                if kv_caches[0] is not None:
 | 
			
		||||
                    presents = presents + (kv_cache,)
 | 
			
		||||
                # prefilling in decoding, use tensor format to save cuda memory
 | 
			
		||||
                else:
 | 
			
		||||
                    # bigdl-llm change starts
 | 
			
		||||
                    # to fix first token's kv cache error of tensor format in pipeline parallel
 | 
			
		||||
                    if isinstance(kv_cache, tuple):
 | 
			
		||||
                        kv_cache = torch.tensor(kv_cache,
 | 
			
		||||
                                                dtype=hidden_states.dtype).to(hidden_states.device)
 | 
			
		||||
                    # bigdl-llm change ends
 | 
			
		||||
                    presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
 | 
			
		||||
                    if len(presents) == 0:
 | 
			
		||||
                        presents = kv_cache
 | 
			
		||||
                    else:
 | 
			
		||||
                        # bigdl-llm change starts
 | 
			
		||||
                        # to fix first token's kv cache error of tensor format in pipeline parallel
 | 
			
		||||
                        if isinstance(kv_cache, tuple):
 | 
			
		||||
                            kv_cache = torch.tensor(
 | 
			
		||||
                                kv_cache,
 | 
			
		||||
                                dtype=hidden_states.dtype).to(hidden_states.device)
 | 
			
		||||
                        # bigdl-llm change ends
 | 
			
		||||
                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
 | 
			
		||||
 | 
			
		||||
    if output_hidden_states:
 | 
			
		||||
        all_hidden_states = all_hidden_states + (hidden_states,)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue