llama 3.1/3.2 support compresskv (#12347)
* llama 3.1/3.2 support compresskv
* update
* fix transformers 4.45 error
* fix style
* fix typo
* disable llama3.2 1b compresskv
This commit is contained in:
parent
d984c0672a
commit
f24352aef9
2 changed files with 50 additions and 5 deletions
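For orientation: the CompressKV path added by this commit is chosen at runtime inside llama_model_forward via should_use_compresskv, so the user-facing loading code does not change. Below is a minimal, hedged sketch of running a Llama 3.1 model with ipex-llm; the IPEX_LLM_COMPRESS_KV_CACHE environment variable is an assumption (how should_use_compresskv is toggled is not shown in this diff), and the model id is only an example.

# Sketch only: standard ipex-llm loading flow on an Intel XPU.
# The env var name below is an assumption, not confirmed by this diff.
import os
os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"

import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "meta-llama/Llama-3.1-8B-Instruct"      # example model id
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True,
                                             optimize_model=True, use_cache=True)
model = model.half().to("xpu")
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("A long prompt that benefits from kv compression ...",
                   return_tensors="pt").to("xpu")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))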
@@ -356,6 +356,22 @@ class DynamicCompressCache(DynamicCache):
             return 0
         return self.real_kv_len
 
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        num_hidden_layers: int = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls(num_hidden_layers)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                invalidInputError(
+                    len(key_states) == 0 and len(value_states) == 0,
+                    "from_legacy_cache should be called with an empty kv cache.")
+        return cache
+
 
 class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
     def update(
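Note what the new classmethod actually does: it only builds an empty compressed cache (the FP8 subclass inherits it unchanged) and rejects any legacy cache that already holds tensors, since compression state cannot be reconstructed from materialized kv tensors. A self-contained sketch of that contract, using a stand-in class rather than the real ipex_llm one:

# Sketch only: a stand-in class that mirrors the from_legacy_cache contract above.
from typing import Optional, Tuple
import torch


class _CompressCacheSketch:
    """Toy stand-in for DynamicCompressCache's construction path."""

    def __init__(self, num_hidden_layers: Optional[int] = None):
        self.key_cache, self.value_cache = [], []
        self.real_kv_len = 0

    @classmethod
    def from_legacy_cache(cls,
                          past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
                          num_hidden_layers: Optional[int] = None) -> "_CompressCacheSketch":
        cache = cls(num_hidden_layers)
        if past_key_values is not None:
            for key_states, value_states in past_key_values:
                # Only an empty legacy cache is accepted; compression state
                # cannot be rebuilt from already-materialized kv tensors.
                assert len(key_states) == 0 and len(value_states) == 0, \
                    "from_legacy_cache should be called with an empty kv cache."
        return cache


# Typical call site: past_key_values is None (or empty) at the start of prefill.
empty = _CompressCacheSketch.from_legacy_cache(None, num_hidden_layers=32)
print(type(empty).__name__, empty.real_kv_len)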
@@ -50,7 +50,10 @@ from ipex_llm.transformers.models.common import attention_softmax
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, \
+    is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \
+    DynamicCompressFp8Cache
 
 
 def llama_model_forward(
@@ -83,11 +86,25 @@ def llama_model_forward(
         self.layers[0].mlp.down_proj, inputs,
         self.config.num_attention_heads // self.config.num_key_value_heads
     )
+    use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    # disable llama3.2 1b for prefill performance and output quality
+    use_compresskv = use_compresskv and self.config.hidden_size != 2048
+
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compresskv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     # IPEX-LLM OPT end
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
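The branch order in llama_model_forward is the substance of the change: CompressKV wins when enabled, and FP8 quantization then only picks between DynamicCompressFp8Cache and DynamicCompressCache; otherwise the existing FP8/normal fallbacks apply, and hidden_size != 2048 is the llama3.2 1b exclusion from the commit message. A minimal sketch of that selection as a standalone function (cache classes reduced to strings for illustration):

# Sketch only: mirrors the cache-class selection above with plain strings.
def select_kv_cache_class(use_cache: bool, use_quantize_kv: bool,
                          use_compresskv: bool, hidden_size: int) -> str:
    # llama3.2 1b (hidden_size == 2048) is excluded from CompressKV
    # for prefill performance and output quality.
    use_compresskv = use_compresskv and hidden_size != 2048
    if not use_cache:
        return "no-cache"
    if use_compresskv:
        return "DynamicCompressFp8Cache" if use_quantize_kv else "DynamicCompressCache"
    if use_quantize_kv:
        return "DynamicFp8Cache"
    return "DynamicNormalCache"


# Llama 3.1 8B with quantized kv and CompressKV enabled -> compressed FP8 cache.
print(select_kv_cache_class(True, True, True, hidden_size=4096))   # DynamicCompressFp8Cache
# Llama 3.2 1B is excluded -> falls back to the plain FP8 cache.
print(select_kv_cache_class(True, True, True, hidden_size=2048))   # DynamicFp8Cache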
@@ -182,6 +199,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -201,6 +221,15 @@ def llama_attention_forward(
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
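In the attention path only the call shape changes: the compressed cache's update() additionally receives the current query states, the attention mask, the GQA group count, the model config, whether there is still room in the allocated kv buffer, and a window size (256 here), because it decides which kv entries to keep; the normal and FP8 caches keep the short signature. A minimal sketch of that dispatch, with toy cache classes standing in for the real ones (the "keep the last window" behavior below is a placeholder, not the actual compression algorithm):

# Sketch only: toy caches that mimic the two update() signatures used above.
import torch


class _NormalCache:
    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # Plain append-style kv cache: keep everything.
        return key_states, value_states


class _CompressCache(_NormalCache):
    def update(self, key_states, value_states, layer_idx,
               query_states, attention_mask, num_key_value_groups,
               config, enough_kv_room, window_size):
        # A real CompressKV cache would score kv entries against the recent
        # queries and keep only the important ones plus a local window;
        # here we simply keep the last `window_size` positions as a stand-in.
        return key_states[:, :, -window_size:], value_states[:, :, -window_size:]


def cached_kv(past_key_value, key_states, value_states, layer_idx,
              query_states, attention_mask, num_key_value_groups, config,
              enough_kv_room, window_size=256):
    use_compresskv = isinstance(past_key_value, _CompressCache)
    if use_compresskv:
        return past_key_value.update(
            key_states, value_states, layer_idx,
            query_states, attention_mask, num_key_value_groups,
            config, enough_kv_room, window_size)
    return past_key_value.update(key_states, value_states, layer_idx, None)


k = v = q = torch.randn(1, 8, 1024, 64)       # [batch, heads, seq, head_dim]
print(cached_kv(_CompressCache(), k, v, 0, q, None, 4, None, True)[0].shape)
print(cached_kv(_NormalCache(), k, v, 0, q, None, 4, None, True)[0].shape)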