Support compress KV with quantize KV (#11812)

* update llama

* support llama 4.41

* fix style

* support minicpm

* support qwen2

* support minicpm & update

* support chatglm4

* support chatglm

* remove print

* add DynamicCompressFp8Cache & support qwen

* support llama

* support minicpm phi3

* update chatglm2/4

* small fix & support qwen 4.42

* remove print
Author: Yina Chen, 2024-08-19 10:32:32 +03:00 (committed by GitHub)
Parent: 6841a9ac8f
Commit: 3cd4e87168
7 changed files with 298 additions and 147 deletions
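
Summary: compressed KV cache (CompressKV) and FP8-quantized KV cache are no longer mutually exclusive. A new DynamicCompressFp8Cache keeps the compressed keys/values in FP8, and the model forwards touched here (llama, chatglm2/4, minicpm, phi3, qwen2) pick the cache class from both flags. Below is a minimal sketch of that selection policy, assuming the two flags are already computed; the helper function is illustrative only, the real checks live inline in each model forward.

from ipex_llm.transformers.kv import (
    DynamicFp8Cache, DynamicCompressCache, DynamicCompressFp8Cache,
)

def select_kv_cache(past_key_values, use_compress_kv, use_quantize_kv):
    # Hypothetical helper mirroring the per-model logic in this commit.
    if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
        # CompressKV now also works with FP8-quantized storage.
        cls = DynamicCompressFp8Cache if use_quantize_kv else DynamicCompressCache
        return cls.from_legacy_cache(past_key_values)
    if use_quantize_kv and not isinstance(past_key_values,
                                          (DynamicFp8Cache, DynamicCompressCache)):
        return DynamicFp8Cache.from_legacy_cache(past_key_values)
    return past_key_values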


@@ -218,8 +218,6 @@ class DynamicCompressCache(DynamicCache):
def __init__(self, quant_kv=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.real_kv_len = 0
self.quant_kv = quant_kv
self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache
def update_seen_tokens(self, layer_idx, q_len):
if layer_idx == 0:
@@ -266,38 +264,25 @@ class DynamicCompressCache(DynamicCache):
value_states=value_states,
attention_mask=attention_mask,
num_key_value_groups=num_key_value_groups)
self.key_cache.append(key_states_compress)
self.value_cache.append(value_states_compress)
if not self.quant_kv:
k_cache_compressed, v_cache_compressed = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
key_states.dtype, key_states.device
)
else:
k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
bsz, num_heads, seq_len, head_dim,
device=key_states.device,
)
k_cache_compressed, v_cache_compressed = self.append_kv_func(
k_cache_compressed, v_cache_compressed = append_kv_cache(
k_cache_compressed, v_cache_compressed,
key_states_compress, value_states_compress)
self.key_cache[layer_idx] = k_cache_compressed
self.value_cache[layer_idx] = v_cache_compressed
self.key_cache.append(k_cache_compressed)
self.value_cache.append(v_cache_compressed)
if key_states.stride(2) != head_dim:
if not self.quant_kv:
k_cache, v_cache = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states.size(2),
key_states.dtype, key_states.device
)
else:
k_cache, v_cache = init_fp8_kv_cache(
bsz, num_heads, 0, head_dim, key_states.device
)
k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
k_cache, v_cache = append_kv_cache(k_cache, v_cache,
key_states, value_states)
return k_cache, v_cache
else:
@@ -305,7 +290,7 @@ class DynamicCompressCache(DynamicCache):
else:
cache_k = self.key_cache[layer_idx]
cache_v = self.value_cache[layer_idx]
if not enough_kv_room and not self.quant_kv:
if not enough_kv_room:
# allocate new
new_c_k, new_c_v = extend_kv_cache(
bsz,
@@ -321,7 +306,7 @@ class DynamicCompressCache(DynamicCache):
cache_k = new_c_k
cache_v = new_c_v
key_states, value_states = self.append_kv_func(cache_k,
key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
@@ -339,13 +324,74 @@ class DynamicCompressCache(DynamicCache):
return 0
return self.real_kv_len
@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
quantize_kv: Optional[bool] = False) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls(quantize_kv)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
query_states: torch.Tensor,
attention_mask: torch.Tensor,
num_key_value_groups: int,
attn_config: Dict[str, Any],
enough_kv_room: bool,
KV_CACHE_ALLOC_BLOCK_LENGTH: int,
cache_kwargs: Optional[Dict[str, Any]]=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
bsz, num_heads, seq_len, head_dim = key_states.shape
if layer_idx == 0:
if hasattr(self, "_seen_tokens"):
# 4.39 uses `_seen_tokens`
self._seen_tokens += seq_len
else:
# 4.37 uses `seen_tokens`
self.seen_tokens += seq_len
self.real_kv_len += seq_len
# Update the cache
if len(self.key_cache) <= layer_idx:
# First token, compress kv cache
key_states_compress, value_states_compress = compress_kv(
attn_config=attn_config,
key_states=key_states,
query_states=query_states,
value_states=value_states,
attention_mask=attention_mask,
num_key_value_groups=num_key_value_groups)
k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
bsz, num_heads, seq_len, head_dim,
device=key_states.device,
)
k_cache_compressed, v_cache_compressed = append_fp8_kv_cache(
k_cache_compressed, v_cache_compressed,
key_states_compress, value_states_compress)
self.key_cache.append(k_cache_compressed)
self.value_cache.append(v_cache_compressed)
if key_states.stride(2) != head_dim:
k_cache, v_cache = init_fp8_kv_cache(
bsz, num_heads, 0, head_dim, key_states.device
)
k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache,
key_states, value_states)
return k_cache, v_cache
else:
return key_states, value_states
else:
cache_k = self.key_cache[layer_idx]
cache_v = self.value_cache[layer_idx]
key_states, value_states = append_fp8_kv_cache(cache_k,
cache_v,
key_states,
value_states)
# update past_key_value
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
return self.key_cache[layer_idx], self.value_cache[layer_idx]
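
For context, every attention forward in this commit drives the compressed cache through the same update() signature, so DynamicCompressCache and its new FP8 subclass are interchangeable at the call site. A hedged sketch of that call pattern follows; the wrapper function and its parameter names are illustrative and not part of the diff.

from ipex_llm.transformers.models.utils import is_enough_kv_cache_room_4_36

def update_compressed_kv(past_key_value, key_states, value_states, query_states,
                         attention_mask, layer_idx, num_key_value_groups,
                         attn_config, q_len, alloc_block_length=256):
    # Both DynamicCompressCache and DynamicCompressFp8Cache accept this argument
    # list; on the first token they compress the prefill KV before caching
    # (in FP8 for the Fp8 variant), afterwards they append to the compressed cache.
    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, layer_idx, q_len)
    return past_key_value.update(
        key_states, value_states, layer_idx,
        query_states, attention_mask, num_key_value_groups,
        attn_config, enough_kv_room, alloc_block_length)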


@@ -17,6 +17,7 @@
# https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
#
import os
import math
import torch
from typing import Optional, Tuple
@@ -27,7 +28,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, u
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.kv import DynamicCompressCache
from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -90,8 +93,11 @@ def chatglm2_model_forward(
use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
input_ids)
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
if full_attention_mask is None:
@@ -279,15 +285,9 @@ def chatglm2_attention_forward(
# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
if use_quantize_kv or (not use_compresskv):
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
)
# past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
past_key_value = (key_states.permute(2, 0, 1, 3),
value_states.permute(2, 0, 1, 3)) if use_cache else None
else:
# [CompressKV]
if use_compresskv:
from transformers.configuration_utils import PretrainedConfig
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -296,8 +296,16 @@ def chatglm2_attention_forward(
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_number - 1,
query_states, attention_mask, n_head // n_kv_head,
self.config, enough_kv_room, 256
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
)
# past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
past_key_value = (key_states.permute(2, 0, 1, 3),
value_states.permute(2, 0, 1, 3)) if use_cache else None
# IPEX-LLM OPT: sdp
attn_weights = None


@@ -17,6 +17,7 @@
# https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/configuration_chatglm.py
#
import os
import torch
from typing import Optional, Tuple, Union
from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
@@ -25,10 +26,12 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
from ipex_llm.transformers.models.chatglm2 import repeat_kv
from ipex_llm.transformers.kv import DynamicCompressCache
from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
import math
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
def chatglm4_model_forward(
self,
@@ -54,8 +57,11 @@ def chatglm4_model_forward(
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
inputs)
if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
if inputs_embeds is None:
@@ -201,7 +207,19 @@ def chatglm4_attention_forward(
# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
if use_quantize_kv or (not use_compresskv):
# [CompressKV]
if use_compresskv:
from transformers.configuration_utils import PretrainedConfig
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_number - 1,
q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_number - 1,
query_states, attention_mask, n_head // n_kv_head,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
@@ -214,30 +232,19 @@ def chatglm4_attention_forward(
past_key_value = (key_states, value_states)
else:
past_key_value = None
else:
from transformers.configuration_utils import PretrainedConfig
self.config = self.config if hasattr(self, "config") else PretrainedConfig()
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
self.layer_number - 1,
q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_number - 1,
query_states, attention_mask, n_head // n_kv_head,
self.config, enough_kv_room, 256
)
# IPEX-LLM OPT: sdp
attn_weights = None
if use_sdp(q_len, kv_seq_len, head_dim, query_states):
import xe_addons
if use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
else:
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
import xe_addons
if use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
if use_quantize_kv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
attention_mask)
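
Because compression shortens the cached sequence, the attention mask has to be trimmed to the compressed KV length before the SDP kernels are called. A minimal sketch of what get_compresskv_attn_mask is assumed to do, based on the inline slicing this commit removes from the phi3 path:

import torch

def compresskv_attn_mask(key_states: torch.Tensor, attention_mask):
    # Keep only the last key_states.size(2) mask columns so the mask width
    # matches the compressed KV cache (assumed equivalent to
    # get_compresskv_attn_mask in ipex_llm.transformers.models.utils).
    if attention_mask is not None:
        attention_mask = attention_mask[:, :, :, -key_states.size(2):]
    return attention_mask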


@@ -120,19 +120,25 @@ def llama_model_forward_4_36(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
DynamicCompressFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input, input.shape[1]):
# if use quantize kv, compress kv will be ignored now
use_quantize = use_quantize_kv_cache(
self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads)
if should_use_compresskv(input, input.shape[1]):
if not isinstance(past_key_values, DynamicCompressCache):
if use_quantize:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
elif use_quantize:
if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return llama_model_forward_4_36_internal(
self=self,
input_ids=input_ids,
@@ -160,19 +166,25 @@ def llama_model_forward_4_38(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
DynamicCompressFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input, input.shape[1]):
# if use quantize kv, compress kv will be ignored now
use_quantize = use_quantize_kv_cache(
self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads)
if should_use_compresskv(input, input.shape[1]):
if not isinstance(past_key_values, DynamicCompressCache):
if use_quantize:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
elif use_quantize:
if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return llama_model_forward_4_38_internal(
self=self,
input_ids=input_ids,
@@ -201,19 +213,25 @@ def llama_model_forward_4_41(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
DynamicCompressFp8Cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
if use_cache:
if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads):
if not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif should_use_compresskv(input, input.shape[1]):
# if use quantize kv, compress kv will be ignored now
use_quantize = use_quantize_kv_cache(
self.layers[0].mlp.up_proj, input,
self.config.num_attention_heads//self.config.num_key_value_heads)
if should_use_compresskv(input, input.shape[1]):
if not isinstance(past_key_values, DynamicCompressCache):
if use_quantize:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(
past_key_values)
elif use_quantize:
if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
return llama_model_forward_4_41_internal(
self=self,
input_ids=input_ids,
@@ -1086,6 +1104,7 @@ def llama_attention_forward_4_41_quantized(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
from ipex_llm.transformers.kv import DynamicCompressCache
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1102,6 +1121,9 @@ def llama_attention_forward_4_41_quantized(
enough_kv_room,
bsz * q_len,
llama_decoding_fast_path_qtype_check) and no_tp
# [CompressKV]
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1177,6 +1199,13 @@ def llama_attention_forward_4_41_quantized(
repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
if use_cache:
cache_kwargs = None
# [CompressKV]
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
@@ -1227,6 +1256,13 @@ def llama_attention_forward_4_41_quantized(
attn_output = torch.matmul(attn_weights, repeated_value_states)
else:
cache_kwargs = None # Specific to RoPE models
# [CompressKV]
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
kv_seq_len = key_states.shape[-2]
@@ -1275,6 +1311,11 @@ def llama_attention_forward_4_41_quantized(
new_attn_mask = attention_mask[:, :, :, 0:kv_seq_len]
else:
new_attn_mask = attention_mask
# [CompressKV]
if use_compresskv:
new_attn_mask = get_compresskv_attn_mask(key_states,
new_attn_mask)
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, new_attn_mask)
attn_weights = None
@@ -1652,6 +1693,7 @@ def llama_attention_forward_4_38_quantized(
cache_position: Optional[torch.LongTensor] = None,
**kwargs
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
from ipex_llm.transformers.kv import DynamicCompressCache
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1668,6 +1710,10 @@ def llama_attention_forward_4_38_quantized(
enough_kv_room,
bsz * q_len,
llama_decoding_fast_path_qtype_check) and no_tp
# [CompressKV]
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
if decoding_fast_path:
hidden_states = hidden_states.view(1, -1)
tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1743,8 +1789,16 @@ def llama_attention_forward_4_38_quantized(
repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
if use_cache:
cache_kwargs = None
# [CompressKV]
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
query_states, self.training):
import xe_addons
@@ -1793,6 +1847,13 @@ def llama_attention_forward_4_38_quantized(
attn_output = torch.matmul(attn_weights, repeated_value_states)
else:
cache_kwargs = None # Specific to RoPE models
# [CompressKV]
if use_compresskv:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, cache_kwargs)
kv_seq_len = key_states.shape[-2]
@@ -1841,6 +1902,11 @@ def llama_attention_forward_4_38_quantized(
new_attn_mask = attention_mask[:, :, kv_seq_len-q_len:kv_seq_len, 0:kv_seq_len]
else:
new_attn_mask = attention_mask
# [CompressKV]
if use_compresskv:
new_attn_mask = get_compresskv_attn_mask(key_states,
new_attn_mask)
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, new_attn_mask)
attn_weights = None


@@ -47,7 +47,8 @@ from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compres
from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
from ipex_llm.transformers.models.llama import repeat_kv
from ipex_llm.transformers.models.common import merge_qkv_base
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
DynamicCompressCache, DynamicCompressFp8Cache
from transformers.cache_utils import Cache
@@ -79,6 +80,10 @@ def minicpm_attention_forward(
self.num_key_value_heads,
self.num_key_value_heads], dim=1)
# [CompressKV]
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
@@ -94,7 +99,7 @@ def minicpm_attention_forward(
)
if past_key_value is not None:
if isinstance(past_key_value, DynamicCompressCache):
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
@@ -107,10 +112,11 @@ def minicpm_attention_forward(
attn_weights = None
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
import xe_addons
if isinstance(past_key_value, DynamicCompressCache):
# [CompressKV]
if use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
elif isinstance(past_key_value, DynamicFp8Cache):
if use_quantizekv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
@@ -118,14 +124,14 @@ def minicpm_attention_forward(
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
if use_quantizekv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if isinstance(past_key_value, DynamicFp8Cache):
if use_quantizekv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -180,11 +186,15 @@ def minicpm_model_forward_wrapper(origin_forward):
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and use_compress_kv and not isinstance(past_key_values,
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif (not use_quantize_kv and not use_compress_kv
and not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)


@@ -40,10 +40,11 @@ from torch import nn
from ipex_llm.transformers.models.common import attention_softmax
from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
DynamicCompressCache, DynamicCompressFp8Cache
from typing import Optional, Tuple, List
from transformers.models.phi.modeling_phi import repeat_kv
@@ -100,6 +101,7 @@ def attention_forward(
# [CompressKV]
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
@@ -150,12 +152,9 @@ def attention_forward(
if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
# [CompressKV]
if use_compresskv:
# print(attention_mask.shape)
context_len = key_states.size(2)
attention_mask = attention_mask[:, :, :, -context_len:]
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
import xe_addons
if isinstance(past_key_value,
DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
if use_quantizekv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
@@ -171,8 +170,7 @@ def attention_forward(
# attn_output = xe_addons.sdp_causal(query_states, key_states,
# value_states, attention_mask)
else:
if isinstance(past_key_value,
DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
if use_quantizekv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
@@ -262,11 +260,12 @@ def phi3_model_forward_wrapper(origin_model_forward):
if use_cache:
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
past_key_values = DynamicCompressCache.\
from_legacy_cache(past_key_values,
quantize_kv=use_quantize_kv)
if use_quantize_kv and not isinstance(past_key_values,
(DynamicFp8Cache, DynamicCompressCache)):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,


@@ -51,7 +51,8 @@ from ipex_llm.transformers.models.utils import should_use_fuse_rope
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
DynamicCompressCache, DynamicCompressFp8Cache
from ipex_llm.utils.common import invalidInputError
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
@@ -122,11 +123,14 @@ def qwen2_model_forward(
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,
DynamicCompressCache)):
@@ -312,10 +316,20 @@ def qwen2_model_forward_4_42(
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
self.config.num_attention_heads//self.config.num_key_value_heads)
)
use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1])
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
if use_quantize_kv:
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
else:
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,
DynamicCompressCache)):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
# ipex-llm changes end
@@ -522,6 +536,7 @@ def qwen2_attention_forward(
# [CompressKV]
from ipex_llm.transformers.kv import DynamicCompressCache
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
qkv = self.qkv_proj(hidden_states)
@@ -592,7 +607,7 @@ def qwen2_attention_forward(
import xe_addons
if use_compresskv:
attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
if isinstance(past_key_value, DynamicFp8Cache):
if use_quantizekv:
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
@@ -600,14 +615,14 @@ def qwen2_attention_forward(
attention_mask)
elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
if use_quantizekv:
attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
value_states, attention_mask)
else:
attn_output = xe_addons.sdp_causal(query_states, key_states,
value_states, attention_mask)
else:
if isinstance(past_key_value, DynamicFp8Cache):
if use_quantizekv:
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
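
Note on the attention-side dispatch: since DynamicCompressFp8Cache subclasses both DynamicCompressCache and DynamicFp8Cache, the per-layer checks reduce to two independent isinstance tests, and the FP8 SDP kernels are picked up automatically when the compressed cache is quantized. A quick illustrative check (assumes an ipex-llm build that includes this commit):

from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
    DynamicCompressFp8Cache

cache = DynamicCompressFp8Cache()
use_compresskv = isinstance(cache, DynamicCompressCache)  # True -> compressed update path
use_quantizekv = isinstance(cache, DynamicFp8Cache)       # True -> sdp_fp8 / sdp_fp8_causal
assert use_compresskv and use_quantizekv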