add compresskv back for mistral (#12607)
* add compresskv back for mistral * fix * fix
This commit is contained in:
		
							parent
							
								
									9c9800be31
								
							
						
					
					
						commit
						4e6b9d804f
					
				
					 2 changed files with 29 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -37,6 +37,7 @@
 | 
			
		|||
 | 
			
		||||
from typing import Optional, Tuple, Union, List
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import torch
 | 
			
		||||
from transformers.cache_utils import Cache
 | 
			
		||||
from transformers.modeling_outputs import BaseModelOutputWithPast
 | 
			
		||||
| 
						 | 
				
			
			@ -45,8 +46,11 @@ from transformers.models.mistral.modeling_mistral import MistralModel, MistralAt
 | 
			
		|||
from ipex_llm.transformers.models.common import merge_qkv_base
 | 
			
		||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 | 
			
		||||
from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def mistral_model_forward(
 | 
			
		||||
| 
						 | 
				
			
			@ -69,11 +73,22 @@ def mistral_model_forward(
 | 
			
		|||
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
 | 
			
		||||
                                            self.config.num_attention_heads //
 | 
			
		||||
                                            self.config.num_key_value_heads)
 | 
			
		||||
    use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
 | 
			
		||||
        isinstance(past_key_values, DynamicCompressCache)
 | 
			
		||||
 | 
			
		||||
    if use_cache:
 | 
			
		||||
        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
        if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
 | 
			
		||||
            if use_quantize_kv:
 | 
			
		||||
                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
            else:
 | 
			
		||||
                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
 | 
			
		||||
        elif (
 | 
			
		||||
            not use_quantize_kv
 | 
			
		||||
            and not use_compress_kv
 | 
			
		||||
            and not isinstance(past_key_values, DynamicNormalCache)
 | 
			
		||||
        ):
 | 
			
		||||
            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
 | 
			
		||||
    # ipex-llm changes end
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -127,8 +142,16 @@ def mistral_attention_forward(
 | 
			
		|||
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
 | 
			
		||||
                                                        cos, sin, position_ids, "mistral")
 | 
			
		||||
 | 
			
		||||
    key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                     self.layer_idx, None)
 | 
			
		||||
    if isinstance(past_key_value, DynamicCompressCache):
 | 
			
		||||
        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
 | 
			
		||||
        key_states, value_states = past_key_value.update(
 | 
			
		||||
            key_states, value_states, self.layer_idx,
 | 
			
		||||
            query_states, attention_mask, self.num_key_value_groups,
 | 
			
		||||
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        key_states, value_states = past_key_value.update(key_states, value_states,
 | 
			
		||||
                                                         self.layer_idx, None)
 | 
			
		||||
 | 
			
		||||
    # IPEX-LLM OPT: sdpa
 | 
			
		||||
    attn_weights = None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,7 +52,7 @@ from ipex_llm.ggml.quantize import ggml_tensor_qtype
 | 
			
		|||
from ipex_llm.utils.common import invalidInputError
 | 
			
		||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 | 
			
		||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 | 
			
		||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
 | 
			
		||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 | 
			
		||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 | 
			
		||||
| 
						 | 
				
			
			@ -171,7 +171,7 @@ def mixtral_attention_forward(
 | 
			
		|||
    # for flash attention
 | 
			
		||||
    original_dtype = hidden_states.dtype
 | 
			
		||||
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
 | 
			
		||||
    use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
 | 
			
		||||
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
 | 
			
		||||
    decoding_fast_path = use_decoding_fast_path(self.q_proj,
 | 
			
		||||
                                                use_fuse_rope,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue