optimize starcoder normal kv cache (#10642)
This commit is contained in:
parent
3a9ab8f1ae
commit
702e686901
2 changed files with 68 additions and 10 deletions
|
|
@ -17,7 +17,10 @@
|
|||
|
||||
import torch
|
||||
|
||||
from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
|
||||
from .models.utils import (
|
||||
init_fp8_kv_cache, append_fp8_kv_cache,
|
||||
init_kv_cache, append_kv_cache
|
||||
)
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
|
|
@ -63,3 +66,58 @@ class DynamicFp8Cache(DynamicCache):
|
|||
self.value_cache[layer_idx] = v_cache
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
|
||||
class DynamicNormalCache(DynamicCache):
|
||||
KV_ALLOC_BLOCK_LENGTH = 256
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]]=None,
|
||||
new_layout=False, # useless, just keep same as DynamicFp8Cache
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
batch_size, num_heads, seq_len, head_dim = key_states.shape
|
||||
|
||||
if layer_idx == 0:
|
||||
if hasattr(self, "_seen_tokens"):
|
||||
# 4.39 uses `_seen_tokens`
|
||||
self._seen_tokens += seq_len
|
||||
else:
|
||||
# 4.37 uses `seen_tokens`
|
||||
self.seen_tokens += seq_len
|
||||
|
||||
# Update the cache
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
k_cache, v_cache = init_kv_cache(
|
||||
batch_size, num_heads, head_dim,
|
||||
0, key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH,
|
||||
key_states.dtype, key_states.device
|
||||
)
|
||||
k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
|
||||
|
||||
self.key_cache.append(k_cache)
|
||||
self.value_cache.append(v_cache)
|
||||
else:
|
||||
k_cache = self.key_cache[layer_idx]
|
||||
v_cache = self.value_cache[layer_idx]
|
||||
|
||||
kv_seq_len = k_cache.size(2) + key_states.size(2)
|
||||
if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
|
||||
new_k_cache, new_v_cache = init_kv_cache(
|
||||
batch_size, num_heads, head_dim,
|
||||
k_cache.size(2), kv_seq_len + self.KV_ALLOC_BLOCK_LENGTH,
|
||||
key_states.dtype, key_states.device
|
||||
)
|
||||
new_k_cache[...] = k_cache[...]
|
||||
new_v_cache[...] = v_cache[...]
|
||||
k_cache = new_k_cache
|
||||
v_cache = new_v_cache
|
||||
k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
|
||||
self.key_cache[layer_idx] = k_cache
|
||||
self.value_cache[layer_idx] = v_cache
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ from ipex_llm.transformers.models.utils import (
|
|||
use_quantize_kv_cache, restore_fp8_kv_cache,
|
||||
apply_rotary_pos_emb_no_cache_xpu
|
||||
)
|
||||
from ipex_llm.transformers.kv import DynamicFp8Cache
|
||||
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
|
||||
from ipex_llm.utils.common.log4Error import invalidInputError
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
|
|
@ -56,6 +56,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
|
|||
def should_use_fuse_rope(self, hidden_states, position_ids):
|
||||
use_fuse_rope = (
|
||||
hidden_states.device.type == "xpu" and
|
||||
hidden_states.numel() == hidden_states.size(-1) and
|
||||
not (self.training and hidden_states.requires_grad) and
|
||||
position_ids is not None
|
||||
)
|
||||
|
|
@ -130,12 +131,8 @@ def attention_forward(
|
|||
"`past_key_value` cannot be None")
|
||||
use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
|
||||
|
||||
if use_quantize_kv:
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None, new_layout=True)
|
||||
else:
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None)
|
||||
key_states, value_states = past_key_value.update(key_states, value_states,
|
||||
self.layer_idx, None, new_layout=True)
|
||||
|
||||
if use_quantize_kv and q_len == 1:
|
||||
import linear_q4_0
|
||||
|
|
@ -188,9 +185,12 @@ def model_forward(
|
|||
return_dict: Optional[bool] = None,
|
||||
):
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
if use_cache and use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids):
|
||||
if not isinstance(past_key_values, DynamicFp8Cache):
|
||||
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
|
||||
if use_cache:
|
||||
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
|
||||
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
|
||||
if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
|
||||
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
|
||||
return Starcoder2Model.forward(
|
||||
self=self,
|
||||
input_ids=input_ids,
|
||||
|
|
|
|||
Loading…
Reference in a new issue