optimize starcoder normal kv cache (#10642)

This commit is contained in:
Yishuo Wang 2024-04-03 15:27:02 +08:00 committed by GitHub
parent 3a9ab8f1ae
commit 702e686901
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 68 additions and 10 deletions

View file

@ -17,7 +17,10 @@
import torch
from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
from .models.utils import (
init_fp8_kv_cache, append_fp8_kv_cache,
init_kv_cache, append_kv_cache
)
from typing import Optional, Dict, Tuple, Any
from transformers.cache_utils import DynamicCache
@ -63,3 +66,58 @@ class DynamicFp8Cache(DynamicCache):
self.value_cache[layer_idx] = v_cache
return self.key_cache[layer_idx], self.value_cache[layer_idx]
class DynamicNormalCache(DynamicCache):
KV_ALLOC_BLOCK_LENGTH = 256
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]]=None,
new_layout=False, # useless, just keep same as DynamicFp8Cache
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_heads, seq_len, head_dim = key_states.shape
if layer_idx == 0:
if hasattr(self, "_seen_tokens"):
# 4.39 uses `_seen_tokens`
self._seen_tokens += seq_len
else:
# 4.37 uses `seen_tokens`
self.seen_tokens += seq_len
# Update the cache
if len(self.key_cache) <= layer_idx:
k_cache, v_cache = init_kv_cache(
batch_size, num_heads, head_dim,
0, key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH,
key_states.dtype, key_states.device
)
k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
self.key_cache.append(k_cache)
self.value_cache.append(v_cache)
else:
k_cache = self.key_cache[layer_idx]
v_cache = self.value_cache[layer_idx]
kv_seq_len = k_cache.size(2) + key_states.size(2)
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
new_k_cache, new_v_cache = init_kv_cache(
batch_size, num_heads, head_dim,
k_cache.size(2), kv_seq_len + self.KV_ALLOC_BLOCK_LENGTH,
key_states.dtype, key_states.device
)
new_k_cache[...] = k_cache[...]
new_v_cache[...] = v_cache[...]
k_cache = new_k_cache
v_cache = new_v_cache
k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
self.key_cache[layer_idx] = k_cache
self.value_cache[layer_idx] = v_cache
return self.key_cache[layer_idx], self.value_cache[layer_idx]

View file

@ -44,7 +44,7 @@ from ipex_llm.transformers.models.utils import (
use_quantize_kv_cache, restore_fp8_kv_cache,
apply_rotary_pos_emb_no_cache_xpu
)
from ipex_llm.transformers.kv import DynamicFp8Cache
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
from ipex_llm.utils.common.log4Error import invalidInputError
from typing import Optional, Tuple, List
@ -56,6 +56,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
def should_use_fuse_rope(self, hidden_states, position_ids):
use_fuse_rope = (
hidden_states.device.type == "xpu" and
hidden_states.numel() == hidden_states.size(-1) and
not (self.training and hidden_states.requires_grad) and
position_ids is not None
)
@ -130,12 +131,8 @@ def attention_forward(
"`past_key_value` cannot be None")
use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
if use_quantize_kv:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None, new_layout=True)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None, new_layout=True)
if use_quantize_kv and q_len == 1:
import linear_q4_0
@ -188,9 +185,12 @@ def model_forward(
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids):
if not isinstance(past_key_values, DynamicFp8Cache):
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
return Starcoder2Model.forward(
self=self,
input_ids=input_ids,