llama 3.1/3.2 support compresskv (#12347)

* llama 3.1/3.2 support compresskv

* update

* fix transformers 4.45 error

* fix style

* fix typo

* disable llama3.2 1b compresskv
Yina Chen authored on 2024-11-06 11:33:43 +02:00, committed by GitHub
parent d984c0672a
commit f24352aef9
2 changed files with 50 additions and 5 deletions
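
For reviewers who want to exercise the new path end to end, the sketch below shows how a compressed KV cache is typically switched on before loading a Llama 3.1/3.2 model with ipex-llm. It is a minimal sketch and not part of this patch: the IPEX_LLM_COMPRESS_KV_CACHE environment variable is assumed to be the switch that should_use_compresskv checks, and the model id is only an example.

import os
os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"  # assumed opt-in switch; set before loading the model

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-3.1-8B-Instruct"  # example model id (assumption)
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             optimize_model=True,
                                             use_cache=True)
model = model.half().to("xpu")  # compresskv targets the xpu decode path
tokenizer = AutoTokenizer.from_pretrained(model_path)

with torch.inference_mode():
    input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to("xpu")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))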


@@ -356,6 +356,22 @@ class DynamicCompressCache(DynamicCache):
             return 0
         return self.real_kv_len
 
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        num_hidden_layers: int = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls(num_hidden_layers)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                invalidInputError(
+                    len(key_states) == 0 and len(value_states) == 0,
+                    "from_legacy_cache should be called with an empty kv cache.")
+        return cache
+
 
 class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
     def update(
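
For context on the new classmethod above, here is a sketch of the only call pattern it is meant to support: turning an empty (or None) legacy cache into a fresh compress cache at the start of generation. The num_hidden_layers value is just an example, and invalidInputError is assumed to raise when its first argument is False, so a non-empty legacy cache is rejected rather than converted.

from ipex_llm.transformers.kv import DynamicCompressCache

# First forward pass: there is no legacy kv cache yet, so the override simply
# builds an empty DynamicCompressCache (transformers >= 4.45 passes num_hidden_layers).
past_key_values = None
cache = DynamicCompressCache.from_legacy_cache(past_key_values,
                                               num_hidden_layers=32)  # e.g. an 8B Llama (example value)
print(cache.get_seq_length())  # assumed: an empty cache reports sequence length 0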


@@ -50,7 +50,10 @@ from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, \
+    is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \
+    DynamicCompressFp8Cache
 
 
 def llama_model_forward(
@@ -83,11 +86,25 @@ def llama_model_forward(
         self.layers[0].mlp.down_proj, inputs,
         self.config.num_attention_heads // self.config.num_key_value_heads
     )
+    use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    # disable llama3.2 1b for prefill performance and output quality
+    use_compresskv = use_compresskv and self.config.hidden_size != 2048
+
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compresskv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # IPEX-LLM OPT end
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
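
To make the new cache-selection order easier to review, the same decision is restated below as a stand-alone helper. This is a sketch only (select_cache is not a function in this patch); it mirrors the if/elif chain above, where compresskv takes priority and combines with fp8 quantization via DynamicCompressFp8Cache. Note that by this point use_compresskv has already been forced off for hidden_size == 2048 (Llama 3.2 1B).

from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, \
    DynamicCompressCache, DynamicCompressFp8Cache


def select_cache(past_key_values, use_compresskv, use_quantize_kv):
    # Priority: compresskv (optionally fused with fp8 quantization) > fp8-only > normal cache.
    if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
        cache_cls = DynamicCompressFp8Cache if use_quantize_kv else DynamicCompressCache
        return cache_cls.from_legacy_cache(past_key_values)
    if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        return DynamicFp8Cache.from_legacy_cache(past_key_values)
    if (not use_quantize_kv and not use_compresskv
            and not isinstance(past_key_values, DynamicNormalCache)):
        return DynamicNormalCache.from_legacy_cache(past_key_values)
    return past_key_values  # already the right cache type; keep it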
@@ -182,6 +199,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -201,6 +221,15 @@
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
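
The compresskv branch above passes several extra positional arguments to past_key_value.update. The annotated restatement below names them for review; the names, and the reading of 256 as the kv-cache allocation block length, are assumptions inferred from the call order rather than the documented signature, and the snippet reuses the variables already in scope at this call site instead of being standalone.

# Annotated equivalent of the compresskv update call (argument names are assumed).
key_states, value_states = past_key_value.update(
    key_states,                 # new keys for this step
    value_states,               # new values for this step
    self.layer_idx,             # decoder layer being updated
    query_states,               # queries, used to score which cached kv entries to keep
    attention_mask,             # mask applied during that scoring step
    self.num_key_value_groups,  # GQA groups = num_attention_heads // num_key_value_heads
    self.config,                # model config consulted by the compression heuristics
    enough_kv_room,             # whether the pre-allocated cache still has free slots
    256,                        # assumed: allocation block length when the cache must grow
)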