Support compress KV with quantize KV (#11812)
* update llama
* support llama 4.41
* fix style
* support minicpm
* support qwen2
* support minicpm & update
* support chatglm4
* support chatglm
* remove print
* add DynamicCompressFp8Cache & support qwen
* support llama
* support minicpm phi3
* update chatglm2/4
* small fix & support qwen 4.42
* remove print
This commit is contained in:
parent 6841a9ac8f
commit 3cd4e87168
7 changed files with 298 additions and 147 deletions
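At a glance: the commit adds a DynamicCompressFp8Cache alongside the existing DynamicCompressCache, so compressed KV caches can also be stored in fp8 instead of being disabled whenever quantized KV is requested. After this change both variants are exported from the same module, as the per-model import hunks below show:

    from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache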
@@ -218,8 +218,6 @@ class DynamicCompressCache(DynamicCache):
     def __init__(self, quant_kv=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.real_kv_len = 0
-        self.quant_kv = quant_kv
-        self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache

     def update_seen_tokens(self, layer_idx, q_len):
         if layer_idx == 0:
@@ -266,46 +264,33 @@ class DynamicCompressCache(DynamicCache):
                 value_states=value_states,
                 attention_mask=attention_mask,
                 num_key_value_groups=num_key_value_groups)
-            self.key_cache.append(key_states_compress)
-            self.value_cache.append(value_states_compress)

-            if not self.quant_kv:
-                k_cache_compressed, v_cache_compressed = init_kv_cache(
-                    bsz, num_heads, head_dim,
-                    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                    key_states.dtype, key_states.device
-                )
-            else:
-                k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
-                    bsz, num_heads, seq_len, head_dim,
-                    device=key_states.device,
-                )
-            k_cache_compressed, v_cache_compressed = self.append_kv_func(
+            k_cache_compressed, v_cache_compressed = init_kv_cache(
+                bsz, num_heads, head_dim,
+                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache_compressed, v_cache_compressed = append_kv_cache(
                 k_cache_compressed, v_cache_compressed,
                 key_states_compress, value_states_compress)
-            self.key_cache[layer_idx] = k_cache_compressed
-            self.value_cache[layer_idx] = v_cache_compressed
+            self.key_cache.append(k_cache_compressed)
+            self.value_cache.append(v_cache_compressed)

             if key_states.stride(2) != head_dim:
-                if not self.quant_kv:
-                    k_cache, v_cache = init_kv_cache(
-                        bsz, num_heads, head_dim,
-                        0, key_states.size(2),
-                        key_states.dtype, key_states.device
-                    )
-                else:
-                    k_cache, v_cache = init_fp8_kv_cache(
-                        bsz, num_heads, 0, head_dim, key_states.device
-                    )
-                k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
-                                                       key_states, value_states)
+                k_cache, v_cache = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states.size(2),
+                    key_states.dtype, key_states.device
+                )
+                k_cache, v_cache = append_kv_cache(k_cache, v_cache,
+                                                   key_states, value_states)
                 return k_cache, v_cache
             else:
                 return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
-            if not enough_kv_room and not self.quant_kv:
+            if not enough_kv_room:
                 # allocate new
                 new_c_k, new_c_v = extend_kv_cache(
                     bsz,
@@ -321,10 +306,10 @@ class DynamicCompressCache(DynamicCache):
                 cache_k = new_c_k
                 cache_v = new_c_v

-            key_states, value_states = self.append_kv_func(cache_k,
-                                                           cache_v,
-                                                           key_states,
-                                                           value_states)
+            key_states, value_states = append_kv_cache(cache_k,
+                                                       cache_v,
+                                                       key_states,
+                                                       value_states)

             # update past_key_value
             self.key_cache[layer_idx] = key_states
@@ -339,13 +324,74 @@ class DynamicCompressCache(DynamicCache):
             return 0
         return self.real_kv_len

-    @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-                          quantize_kv: Optional[bool] = False) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls(quantize_kv)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.update(key_states, value_states, layer_idx)
-        return cache
+
+class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        query_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        num_key_value_groups: int,
+        attn_config: Dict[str, Any],
+        enough_kv_room: bool,
+        KV_CACHE_ALLOC_BLOCK_LENGTH: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        bsz, num_heads, seq_len, head_dim = key_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+            self.real_kv_len += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            # First token, compress kv cache
+            key_states_compress, value_states_compress = compress_kv(
+                attn_config=attn_config,
+                key_states=key_states,
+                query_states=query_states,
+                value_states=value_states,
+                attention_mask=attention_mask,
+                num_key_value_groups=num_key_value_groups)
+
+            k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
+                bsz, num_heads, seq_len, head_dim,
+                device=key_states.device,
+            )
+
+            k_cache_compressed, v_cache_compressed = append_fp8_kv_cache(
+                k_cache_compressed, v_cache_compressed,
+                key_states_compress, value_states_compress)
+            self.key_cache.append(k_cache_compressed)
+            self.value_cache.append(v_cache_compressed)
+
+            if key_states.stride(2) != head_dim:
+                k_cache, v_cache = init_fp8_kv_cache(
+                    bsz, num_heads, 0, head_dim, key_states.device
+                )
+                k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache,
+                                                       key_states, value_states)
+                return k_cache, v_cache
+            else:
+                return key_states, value_states
+        else:
+            cache_k = self.key_cache[layer_idx]
+            cache_v = self.value_cache[layer_idx]
+
+            key_states, value_states = append_fp8_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+            # update past_key_value
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+
+            return self.key_cache[layer_idx], self.value_cache[layer_idx]
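The net effect of these hunks: DynamicCompressCache is fp16-only again, and the fp8 behaviour moves into the new DynamicCompressFp8Cache, which inherits from both DynamicCompressCache and DynamicFp8Cache and overrides update(). A condensed sketch of its prefill path, using only the helpers named in the diff above (illustrative, not a drop-in copy of the file):

    # First token: compress the prompt KV, then store the survivors in an fp8 cache.
    key_c, value_c = compress_kv(attn_config=attn_config,
                                 key_states=key_states, query_states=query_states,
                                 value_states=value_states, attention_mask=attention_mask,
                                 num_key_value_groups=num_key_value_groups)
    k_cache, v_cache = init_fp8_kv_cache(bsz, num_heads, seq_len, head_dim,
                                         device=key_states.device)
    k_cache, v_cache = append_fp8_kv_cache(k_cache, v_cache, key_c, value_c)
    self.key_cache.append(k_cache)
    self.value_cache.append(v_cache)

On decode steps the new tokens are appended with append_fp8_kv_cache directly, so compression only happens once at prefill.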
@@ -17,6 +17,7 @@
 # https://huggingface.co/THUDM/chatglm2-6b/blob/8eb45c842594b8473f291d0f94e7bbe86ffc67d8/modeling_chatglm.py
 #

+import os
 import math
 import torch
 from typing import Optional, Tuple
@@ -27,7 +28,9 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, u
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     use_sdp_causal, should_use_compresskv, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.kv import DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
+
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -90,9 +93,12 @@ def chatglm2_model_forward(
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
-        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
-                                                                      DynamicCompressCache):
-            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if use_compress_kv and not isinstance(past_key_values,
+                                              DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or (
@@ -279,15 +285,9 @@ def chatglm2_attention_forward(

     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-    if use_quantize_kv or (not use_compresskv):
-        key_states, value_states = update_past_key_value(
-            past_key_value, key_states, value_states,
-            kv_seq_len, use_quantize_kv, hidden_states.device
-        )
-        # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
-        past_key_value = (key_states.permute(2, 0, 1, 3),
-                          value_states.permute(2, 0, 1, 3)) if use_cache else None
-    else:
+
+    # [CompressKV]
+    if use_compresskv:
         from transformers.configuration_utils import PretrainedConfig
         self.config = self.config if hasattr(self, "config") else PretrainedConfig()
         enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -296,8 +296,16 @@ def chatglm2_attention_forward(
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_number - 1,
             query_states, attention_mask, n_head // n_kv_head,
-            self.config, enough_kv_room, 256
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
         )
+    else:
+        key_states, value_states = update_past_key_value(
+            past_key_value, key_states, value_states,
+            kv_seq_len, use_quantize_kv, hidden_states.device
+        )
+        # past_key_value: [bsz, n_kv_head, seq_len, head_dim] -> [seq_len, bsz, n_kv_head, head_dim]
+        past_key_value = (key_states.permute(2, 0, 1, 3),
+                          value_states.permute(2, 0, 1, 3)) if use_cache else None

     # IPEX-LLM OPT: sdp
     attn_weights = None
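The model-forward change above is the pattern this commit repeats for every model: when compress-kv is enabled, the quantize-kv flag now only decides which compress cache class to build, instead of disabling compression altogether. Written out as a standalone helper (illustrative only, not code from the PR):

    def convert_to_compress_cache(past_key_values, use_quantize_kv):
        # use_quantize_kv selects the fp8 variant; otherwise the fp16 compress cache is used
        if use_quantize_kv:
            return DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
        return DynamicCompressCache.from_legacy_cache(past_key_values)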
@@ -17,6 +17,7 @@
 # https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/configuration_chatglm.py
 #

+import os
 import torch
 from typing import Optional, Tuple, Union
 from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, update_past_key_value
@@ -25,10 +26,12 @@ from ipex_llm.transformers.models.utils import use_quantize_kv_cache, use_sdp, \
     get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, apply_rotary_pos_emb
 from ipex_llm.transformers.models.chatglm2 import repeat_kv
-from ipex_llm.transformers.kv import DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 import math

+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
+

 def chatglm4_model_forward(
     self,
@@ -54,9 +57,12 @@ def chatglm4_model_forward(
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
-        if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
-                                                                      DynamicCompressCache):
-            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        if use_compress_kv and not isinstance(past_key_values,
+                                              DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

     if inputs_embeds is None:
         batch_size, seq_length = input_ids.shape
@@ -201,7 +207,19 @@ def chatglm4_attention_forward(
     # IPEX-LLM OPT: kv cache and quantize kv
     use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)

-    if use_quantize_kv or (not use_compresskv):
+    # [CompressKV]
+    if use_compresskv:
+        from transformers.configuration_utils import PretrainedConfig
+        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
+        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
+                                                      self.layer_number - 1,
+                                                      q_len)
+        key_states, value_states = past_key_value.update(
+            key_states, value_states, self.layer_number - 1,
+            query_states, attention_mask, n_head // n_kv_head,
+            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
+        )
+    else:
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, hidden_states.device
@@ -214,30 +232,19 @@ def chatglm4_attention_forward(
             past_key_value = (key_states, value_states)
         else:
             past_key_value = None
-    else:
-        from transformers.configuration_utils import PretrainedConfig
-        self.config = self.config if hasattr(self, "config") else PretrainedConfig()
-        enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
-                                                      self.layer_number - 1,
-                                                      q_len)
-        key_states, value_states = past_key_value.update(
-            key_states, value_states, self.layer_number - 1,
-            query_states, attention_mask, n_head // n_kv_head,
-            self.config, enough_kv_room, 256
-        )

     # IPEX-LLM OPT: sdp
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, head_dim, query_states):
         import xe_addons
+        if use_compresskv:
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, attention_mask)
         else:
             attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, head_dim, query_states, self.training):
         import xe_addons
-        if use_compresskv:
-            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         if use_quantize_kv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states, value_states,
                                                    attention_mask)
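Both chatglm files also stop hard-coding the allocation block length 256 in the attention path and read it from the environment instead, matching the other models in this commit:

    import os
    KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))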
@@ -120,19 +120,25 @@ def llama_model_forward_4_36(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
+        DynamicCompressFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                 self.config.num_attention_heads//self.config.num_key_value_heads):
-            if not isinstance(past_key_values, DynamicFp8Cache):
-                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[1]):
-            # if use quantize kv, compress kv will be ignored now
+        use_quantize = use_quantize_kv_cache(
+            self.layers[0].mlp.up_proj, input,
+            self.config.num_attention_heads//self.config.num_key_value_heads)
+        if should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(
-                    past_key_values)
+                if use_quantize:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
+                        past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(
+                        past_key_values)
+        elif use_quantize:
+            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_36_internal(
         self=self,
         input_ids=input_ids,
@@ -160,19 +166,25 @@ def llama_model_forward_4_38(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
+        DynamicCompressFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                 self.config.num_attention_heads//self.config.num_key_value_heads):
-            if not isinstance(past_key_values, DynamicFp8Cache):
-                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[1]):
-            # if use quantize kv, compress kv will be ignored now
+        use_quantize = use_quantize_kv_cache(
+            self.layers[0].mlp.up_proj, input,
+            self.config.num_attention_heads//self.config.num_key_value_heads)
+        if should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(
-                    past_key_values)
+                if use_quantize:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
+                        past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(
+                        past_key_values)
+        elif use_quantize:
+            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_38_internal(
         self=self,
         input_ids=input_ids,
@@ -201,19 +213,25 @@ def llama_model_forward_4_41(
     return_dict: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache
+    from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache, \
+        DynamicCompressFp8Cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     if use_cache:
-        if use_quantize_kv_cache(self.layers[0].mlp.up_proj, input,
-                                 self.config.num_attention_heads//self.config.num_key_value_heads):
-            if not isinstance(past_key_values, DynamicFp8Cache):
-                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[1]):
-            # if use quantize kv, compress kv will be ignored now
+        use_quantize = use_quantize_kv_cache(
+            self.layers[0].mlp.up_proj, input,
+            self.config.num_attention_heads//self.config.num_key_value_heads)
+        if should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(
-                    past_key_values)
+                if use_quantize:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(
+                        past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(
+                        past_key_values)
+        elif use_quantize:
+            if not isinstance(past_key_values, (DynamicFp8Cache, DynamicCompressCache)):
+                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
     return llama_model_forward_4_41_internal(
         self=self,
         input_ids=input_ids,
@@ -1086,6 +1104,7 @@ def llama_attention_forward_4_41_quantized(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1102,6 +1121,9 @@ def llama_attention_forward_4_41_quantized(
                                          enough_kv_room,
                                          bsz * q_len,
                                          llama_decoding_fast_path_qtype_check) and no_tp
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1177,8 +1199,15 @@ def llama_attention_forward_4_41_quantized(
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         if use_cache:
             cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
+            # [CompressKV]
+            if use_compresskv:
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx,
+                    query_states, attention_mask, self.num_key_value_groups,
+                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+            else:
+                key_states, value_states = past_key_value.update(key_states, value_states,
+                                                                 self.layer_idx, cache_kwargs)
         if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
                                         query_states, self.training):
             import xe_addons
@@ -1227,8 +1256,15 @@ def llama_attention_forward_4_41_quantized(
             attn_output = torch.matmul(attn_weights, repeated_value_states)
     else:
         cache_kwargs = None  # Specific to RoPE models
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+        # [CompressKV]
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
         if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1275,6 +1311,11 @@ def llama_attention_forward_4_41_quantized(
                     new_attn_mask = attention_mask[:, :, :, 0:kv_seq_len]
                 else:
                     new_attn_mask = attention_mask
+
+            # [CompressKV]
+            if use_compresskv:
+                new_attn_mask = get_compresskv_attn_mask(key_states,
+                                                         new_attn_mask)
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, new_attn_mask)
             attn_weights = None

@@ -1652,6 +1693,7 @@ def llama_attention_forward_4_38_quantized(
     cache_position: Optional[torch.LongTensor] = None,
     **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    from ipex_llm.transformers.kv import DynamicCompressCache
     if "padding_mask" in kwargs:
         warnings.warn(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
@@ -1668,6 +1710,10 @@ def llama_attention_forward_4_38_quantized(
                                          enough_kv_room,
                                          bsz * q_len,
                                          llama_decoding_fast_path_qtype_check) and no_tp
+
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     if decoding_fast_path:
         hidden_states = hidden_states.view(1, -1)
         tmp_cache_k, tmp_cache_v = init_kv_cache(
@@ -1743,8 +1789,16 @@ def llama_attention_forward_4_38_quantized(
         repeated_value_states = repeat_kv(value_states, self.num_key_value_groups)
         if use_cache:
             cache_kwargs = None
-            key_states, value_states = past_key_value.update(key_states, value_states,
-                                                             self.layer_idx, cache_kwargs)
+            # [CompressKV]
+            if use_compresskv:
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx,
+                    query_states, attention_mask, self.num_key_value_groups,
+                    self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+            else:
+                key_states, value_states = past_key_value.update(key_states, value_states,
+                                                                 self.layer_idx, cache_kwargs)
+
         if use_cache and use_sdp_causal(q_len, kv_seq_len, self.head_dim,
                                         query_states, self.training):
             import xe_addons
@@ -1793,8 +1847,15 @@ def llama_attention_forward_4_38_quantized(
             attn_output = torch.matmul(attn_weights, repeated_value_states)
     else:
         cache_kwargs = None  # Specific to RoPE models
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, cache_kwargs)
+        # [CompressKV]
+        if use_compresskv:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, cache_kwargs)
         kv_seq_len = key_states.shape[-2]
         if not use_sdp_fp8(q_len, key_states.shape[2], query_states):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
@@ -1841,6 +1902,11 @@ def llama_attention_forward_4_38_quantized(
                 new_attn_mask = attention_mask[:, :, kv_seq_len-q_len:kv_seq_len, 0:kv_seq_len]
             else:
                 new_attn_mask = attention_mask
+
+            # [CompressKV]
+            if use_compresskv:
+                new_attn_mask = get_compresskv_attn_mask(key_states,
+                                                         new_attn_mask)
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states, new_attn_mask)
             attn_weights = None

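Both quantized llama attention paths (4.38 and 4.41) end up with the same two-way dispatch when updating the cache: the compress-kv cache needs the query, the attention mask, and the allocation parameters, while the stock transformers Cache.update signature is kept for everything else. Condensed from the hunks above (illustrative consolidation, not a new function in the PR):

    # [CompressKV] dispatch used in the quantized attention paths
    if use_compresskv:
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx,
            query_states, attention_mask, self.num_key_value_groups,
            self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
    else:
        key_states, value_states = past_key_value.update(key_states, value_states,
                                                         self.layer_idx, cache_kwargs)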
@@ -47,7 +47,8 @@ from ipex_llm.transformers.models.utils import restore_fp8_kv_cache, get_compres
 from ipex_llm.transformers.models.utils import should_use_compresskv, should_use_fuse_rope
 from ipex_llm.transformers.models.llama import repeat_kv
 from ipex_llm.transformers.models.common import merge_qkv_base
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
+    DynamicCompressCache, DynamicCompressFp8Cache
 from transformers.cache_utils import Cache


@@ -79,6 +80,10 @@ def minicpm_attention_forward(
                                                          self.num_key_value_heads,
                                                          self.num_key_value_heads], dim=1)

+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)
+
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
@@ -94,7 +99,7 @@ def minicpm_attention_forward(
     )

     if past_key_value is not None:
-        if isinstance(past_key_value, DynamicCompressCache):
+        if use_compresskv:
             enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, q_len)
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx,
@@ -107,10 +112,11 @@ def minicpm_attention_forward(
     attn_weights = None
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         import xe_addons
-        if isinstance(past_key_value, DynamicCompressCache):
+        # [CompressKV]
+        if use_compresskv:
             attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-            attn_output = xe_addons.sdp(query_states, key_states, value_states, attention_mask)
-        elif isinstance(past_key_value, DynamicFp8Cache):
+
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -118,14 +124,14 @@ def minicpm_attention_forward(
                                             attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
                                                    value_states, attention_mask)
         else:
             attn_output = xe_addons.sdp_causal(query_states, key_states,
                                                value_states, attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -180,11 +186,15 @@ def minicpm_model_forward_wrapper(origin_forward):

         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if use_cache:
-            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            if use_compress_kv and not isinstance(past_key_values,
+                                                  DynamicCompressCache):
+                if use_quantize_kv:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+            elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                      DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-            elif not use_quantize_kv and use_compress_kv and not isinstance(past_key_values,
-                                                                            DynamicCompressCache):
-                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
             elif (not use_quantize_kv and not use_compress_kv
                   and not isinstance(past_key_values, (DynamicNormalCache, DynamicCompressCache))):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
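Because DynamicCompressFp8Cache derives from both DynamicCompressCache and DynamicFp8Cache, the per-layer code can key its branches off two cheap isinstance checks computed once per call, as the minicpm hunks above do. With a compressed fp8 cache both flags are true, so the compress-kv mask handling and the fp8 SDP kernels combine naturally:

    # [CompressKV]
    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)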
@@ -40,10 +40,11 @@ from torch import nn
 from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
-from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
+from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
 from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
+    DynamicCompressCache, DynamicCompressFp8Cache

 from typing import Optional, Tuple, List
 from transformers.models.phi.modeling_phi import repeat_kv
@@ -100,6 +101,7 @@ def attention_forward(

     # [CompressKV]
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)

     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
@@ -150,12 +152,9 @@ def attention_forward(
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
         # [CompressKV]
         if use_compresskv:
-            # print(attention_mask.shape)
-            context_len = key_states.size(2)
-            attention_mask = attention_mask[:, :, :, -context_len:]
+            attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
         import xe_addons
-        if isinstance(past_key_value,
-                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -171,8 +170,7 @@ def attention_forward(
         # attn_output = xe_addons.sdp_causal(query_states, key_states,
         #                                    value_states, attention_mask)
     else:
-        if isinstance(past_key_value,
-                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
+        if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
@@ -262,11 +260,12 @@ def phi3_model_forward_wrapper(origin_model_forward):
         if use_cache:
             if use_compress_kv and not isinstance(past_key_values,
                                                   DynamicCompressCache):
-                past_key_values = DynamicCompressCache.\
-                    from_legacy_cache(past_key_values,
-                                      quantize_kv=use_quantize_kv)
-            if use_quantize_kv and not isinstance(past_key_values,
-                                                  (DynamicFp8Cache, DynamicCompressCache)):
+                if use_quantize_kv:
+                    past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+                else:
+                    past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+            if use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                    DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
             if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
                                                                               (DynamicNormalCache,
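In phi3 the inline mask slicing (and a leftover debug print) is replaced by the shared get_compresskv_attn_mask helper imported from models.utils. Judging by the code it replaces, the helper presumably just trims the mask to the compressed KV length; a hedged sketch of that behaviour (an assumption about the helper, not the actual utils code):

    def get_compresskv_attn_mask(key_states, attention_mask):
        # assumed behaviour: keep only the last key_states.size(2) mask columns
        if attention_mask is not None:
            attention_mask = attention_mask[:, :, :, -key_states.size(2):]
        return attention_mask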
@@ -51,7 +51,8 @@ from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache, \
     should_use_compresskv, is_enough_kv_cache_room_4_36, get_compresskv_attn_mask
 from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp, use_sdp_causal
-from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, DynamicCompressCache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache, \
+    DynamicCompressCache, DynamicCompressFp8Cache
 from ipex_llm.utils.common import invalidInputError

 from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2MLP
@@ -122,11 +123,14 @@ def qwen2_model_forward(
     use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])

     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                  DynamicCompressCache)):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and use_compress_kv and not isinstance(past_key_values,
-                                                                        DynamicCompressCache):
-            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
         if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
                                                                           (DynamicNormalCache,
                                                                            DynamicCompressCache)):
@@ -312,10 +316,20 @@ def qwen2_model_forward_4_42(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
+    use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1])
+
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, (DynamicFp8Cache,
+                                                                  DynamicCompressCache)):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                            (DynamicNormalCache,
+                                                                             DynamicCompressCache)):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # ipex-llm changes end

@@ -522,6 +536,7 @@ def qwen2_attention_forward(
     # [CompressKV]
     from ipex_llm.transformers.kv import DynamicCompressCache
     use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+    use_quantizekv = isinstance(past_key_value, DynamicFp8Cache)

     if hasattr(self, 'qkv_proj') and self.qkv_proj is not None:
         qkv = self.qkv_proj(hidden_states)
@@ -592,7 +607,7 @@ def qwen2_attention_forward(
         import xe_addons
         if use_compresskv:
             attention_mask = get_compresskv_attn_mask(key_states, attention_mask)
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -600,14 +615,14 @@ def qwen2_attention_forward(
                                             attention_mask)
     elif use_sdp_causal(q_len, kv_seq_len, self.head_dim, query_states, self.training):
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             attn_output = xe_addons.sdp_fp8_causal(query_states, key_states,
                                                    value_states, attention_mask)
         else:
             attn_output = xe_addons.sdp_causal(query_states, key_states,
                                                value_states, attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if use_quantizekv:
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads