llama 3.1/3.2 support compresskv (#12347)
* llama 3.1/3.2 support compresskv * update * fix transformers 4.45 error * fix style * fix typo * disable llama3.2 1b compresskv
This commit is contained in:
parent
d984c0672a
commit
f24352aef9
2 changed files with 50 additions and 5 deletions
|
|
@ -356,6 +356,22 @@ class DynamicCompressCache(DynamicCache):
|
|||
return 0
|
||||
return self.real_kv_len
|
||||
|
||||
@classmethod
|
||||
def from_legacy_cache(
|
||||
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
num_hidden_layers: int = None
|
||||
) -> "DynamicCache":
|
||||
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
|
||||
backward compatibility."""
|
||||
cache = cls(num_hidden_layers)
|
||||
if past_key_values is not None:
|
||||
for layer_idx in range(len(past_key_values)):
|
||||
key_states, value_states = past_key_values[layer_idx]
|
||||
invalidInputError(
|
||||
len(key_states) == 0 and len(value_states) == 0,
|
||||
"from_legacy_cache should be called with an empty kv cache.")
|
||||
return cache
|
||||
|
||||
|
||||
class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
|
||||
def update(
|
||||
|
|
|
|||
|
|
@ -50,7 +50,10 @@ from ipex_llm.transformers.models.common import attention_softmax
|
|||
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
|
||||
from ipex_llm.transformers.models.utils import should_use_fuse_rope
|
||||
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
|
||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
|
||||
from ipex_llm.transformers.models.utils import should_use_compresskv, \
|
||||
is_enough_kv_cache_room_4_36
|
||||
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \
|
||||
DynamicCompressFp8Cache
|
||||
|
||||
|
||||
def llama_model_forward(
|
||||
|
|
@ -83,11 +86,25 @@ def llama_model_forward(
|
|||
self.layers[0].mlp.down_proj, inputs,
|
||||
self.config.num_attention_heads // self.config.num_key_value_heads
|
||||
)
|
||||
use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
|
||||
isinstance(past_key_values, DynamicCompressCache)
|
||||
# disable llama3.2 1b for prefill performance and output quality
|
||||
use_compresskv = use_compresskv and self.config.hidden_size != 2048
|
||||
if use_cache:
|
||||
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||
if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
|
||||
if use_quantize_kv:
|
||||
past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
|
||||
else:
|
||||
past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
|
||||
elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
|
||||
elif (
|
||||
not use_quantize_kv
|
||||
and not use_compresskv
|
||||
and not isinstance(past_key_values, DynamicNormalCache)
|
||||
):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
|
||||
# IPEX-LLM OPT end
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
|
@ -182,6 +199,9 @@ def llama_attention_forward(
|
|||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
# [CompressKV]
|
||||
use_compresskv = isinstance(past_key_value, DynamicCompressCache)
|
||||
|
||||
qkv = self.qkv_proj(hidden_states)
|
||||
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
|
||||
qkv = qkv.transpose(1, 2)
|
||||
|
|
@ -201,6 +221,15 @@ def llama_attention_forward(
|
|||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_value is not None:
|
||||
# [CompressKV]
|
||||
if use_compresskv:
|
||||
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
|
||||
q_len)
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx,
|
||||
query_states, attention_mask, self.num_key_value_groups,
|
||||
self.config, enough_kv_room, 256)
|
||||
else:
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue