add compresskv back for mistral (#12607)
* add compresskv back for mistral * fix * fix
This commit is contained in:
parent
9c9800be31
commit
4e6b9d804f
2 changed files with 29 additions and 6 deletions
|
|
@ -37,6 +37,7 @@
|
|||
|
||||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
import os
|
||||
import torch
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
|
|
@ -45,8 +46,11 @@ from transformers.models.mistral.modeling_mistral import MistralModel, MistralAt
|
|||
from ipex_llm.transformers.models.common import merge_qkv_base
|
||||
from ipex_llm.transformers.models.common import scaled_dot_product_attention
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
|
||||
from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache
|
||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
|
||||
from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
|
||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
|
||||
|
||||
|
||||
def mistral_model_forward(
|
||||
|
|
@ -69,11 +73,22 @@ def mistral_model_forward(
|
|||
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
|
||||
self.config.num_attention_heads //
|
||||
self.config.num_key_value_heads)
|
||||
use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
|
||||
isinstance(past_key_values, DynamicCompressCache)
|
||||
|
||||
if use_cache:
|
||||
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||
if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
|
||||
if use_quantize_kv:
|
||||
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
|
||||
else:
|
||||
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
|
||||
elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
|
||||
elif (
|
||||
not use_quantize_kv
|
||||
and not use_compress_kv
|
||||
and not isinstance(past_key_values, DynamicNormalCache)
|
||||
):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
# ipex-llm changes end
|
||||
|
||||
|
|
@ -127,8 +142,16 @@ def mistral_attention_forward(
|
|||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||
cos, sin, position_ids, "mistral")
|
||||
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
if isinstance(past_key_value, DynamicCompressCache):
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx,
|
||||
query_states, attention_mask, self.num_key_value_groups,
|
||||
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
|
||||
)
|
||||
else:
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
# IPEX-LLM OPT: sdpa
|
||||
attn_weights = None
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ from ipex_llm.ggml.quantize import ggml_tensor_qtype
|
|||
from ipex_llm.utils.common import invalidInputError
|
||||
from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
|
||||
from ipex_llm.transformers.models.utils import apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
|
||||
from ipex_llm.transformers.models.mistral import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import use_decoding_fast_path
|
||||
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
|
||||
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
|
||||
|
|
@ -171,7 +171,7 @@ def mixtral_attention_forward(
|
|||
# for flash attention
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||
use_fuse_rope = should_use_fuse_rope(hidden_states, position_ids, self.training)
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
|
||||
decoding_fast_path = use_decoding_fast_path(self.q_proj,
|
||||
use_fuse_rope,
|
||||
|
|
|
|||
Loading…
Reference in a new issue