From 702e68690193020c0ee35100ef141eace314023b Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Wed, 3 Apr 2024 15:27:02 +0800
Subject: [PATCH] optimize starcoder normal kv cache (#10642)

---
 python/llm/src/ipex_llm/transformers/kv.py   | 60 ++++++++++++++++++-
 .../transformers/models/starcoder2.py        | 18 +++---
 2 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index f0b4d657..18acf17b 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -17,7 +17,10 @@
 
 import torch
 
-from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
+from .models.utils import (
+    init_fp8_kv_cache, append_fp8_kv_cache,
+    init_kv_cache, append_kv_cache
+)
 from typing import Optional, Dict, Tuple, Any
 
 from transformers.cache_utils import DynamicCache
@@ -63,3 +66,58 @@ class DynamicFp8Cache(DynamicCache):
         self.value_cache[layer_idx] = v_cache
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+class DynamicNormalCache(DynamicCache):
+    KV_ALLOC_BLOCK_LENGTH = 256
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+        new_layout=False,   # unused; kept to match DynamicFp8Cache's signature
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        batch_size, num_heads, seq_len, head_dim = key_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # transformers 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # transformers 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            k_cache, v_cache = init_kv_cache(
+                batch_size, num_heads, head_dim,
+                0, key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+
+            self.key_cache.append(k_cache)
+            self.value_cache.append(v_cache)
+        else:
+            k_cache = self.key_cache[layer_idx]
+            v_cache = self.value_cache[layer_idx]
+
+            kv_seq_len = k_cache.size(2) + key_states.size(2)
+            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+                new_k_cache, new_v_cache = init_kv_cache(
+                    batch_size, num_heads, head_dim,
+                    k_cache.size(2), kv_seq_len + self.KV_ALLOC_BLOCK_LENGTH,
+                    key_states.dtype, key_states.device
+                )
+                new_k_cache[...] = k_cache[...]
+                new_v_cache[...] = v_cache[...]
+                k_cache = new_k_cache
+                v_cache = new_v_cache
+            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+            self.key_cache[layer_idx] = k_cache
+            self.value_cache[layer_idx] = v_cache
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
diff --git a/python/llm/src/ipex_llm/transformers/models/starcoder2.py b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
index de48d11b..c8e23eac 100644
--- a/python/llm/src/ipex_llm/transformers/models/starcoder2.py
+++ b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -44,7 +44,7 @@ from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     apply_rotary_pos_emb_no_cache_xpu
 )
-from ipex_llm.transformers.kv import DynamicFp8Cache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 
 from typing import Optional, Tuple, List
@@ -56,6 +56,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
 def should_use_fuse_rope(self, hidden_states, position_ids):
     use_fuse_rope = (
         hidden_states.device.type == "xpu" and
+        hidden_states.numel() == hidden_states.size(-1) and
         not (self.training and hidden_states.requires_grad) and
         position_ids is not None
     )
@@ -130,12 +131,8 @@ def attention_forward(
                           "`past_key_value` cannot be None")
 
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
-    if use_quantize_kv:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None, new_layout=True)
-    else:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+    key_states, value_states = past_key_value.update(key_states, value_states,
+                                                     self.layer_idx, None, new_layout=True)
 
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
@@ -188,9 +185,12 @@ def model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids):
-        if not isinstance(past_key_values, DynamicFp8Cache):
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
+    if use_cache:
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     return Starcoder2Model.forward(
         self=self,
         input_ids=input_ids,
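
The mechanism behind DynamicNormalCache's growth path is worth spelling out: init_kv_cache
over-allocates the sequence dimension by KV_ALLOC_BLOCK_LENGTH slots and hands back a narrowed
view, so the view's stride(1) still reflects the full buffer's capacity. That is why the check
`k_cache.stride(1) < kv_seq_len * k_cache.size(3)` detects when the reserved headroom is used
up, and only then does the cache reallocate and copy. Below is a rough standalone sketch of
that strategy; `init_cache` and `append` are hypothetical stand-ins, not ipex-llm's actual
init_kv_cache/append_kv_cache implementations.

    import torch

    KV_ALLOC_BLOCK_LENGTH = 256  # extra sequence slots reserved per allocation


    def init_cache(batch, heads, head_dim, current_len, max_len, dtype, device):
        # Hypothetical stand-in for init_kv_cache: allocate room for `max_len`
        # tokens but expose only the first `current_len` as a narrowed view.
        # The view keeps the full buffer's strides, so stride(1) encodes capacity.
        buf = torch.empty(batch, heads, max_len, head_dim, dtype=dtype, device=device)
        return buf[:, :, :current_len, :]


    def append(cache, new_states):
        # Hypothetical stand-in for append_kv_cache.
        b, h, s, d = cache.shape
        s_new = new_states.size(2)
        if cache.stride(1) < (s + s_new) * d:
            # Headroom exhausted: reallocate with another block of slack and copy.
            bigger = init_cache(b, h, d, s, s + s_new + KV_ALLOC_BLOCK_LENGTH,
                                cache.dtype, cache.device)
            bigger[...] = cache
            cache = bigger
        # Widen the view over the same storage and write the new tokens in place.
        out = cache.as_strided((b, h, s + s_new, d), cache.stride())
        out[:, :, s:, :] = new_states
        return out


    # Prompt of 16 tokens, then two single-token decode steps: no reallocation
    # happens until the 256-slot headroom is consumed.
    k = init_cache(1, 8, 64, 0, 16 + KV_ALLOC_BLOCK_LENGTH, torch.float32, "cpu")
    k = append(k, torch.randn(1, 8, 16, 64))
    k = append(k, torch.randn(1, 8, 1, 64))
    k = append(k, torch.randn(1, 8, 1, 64))

With a block size of 256, a reallocation-plus-copy happens at most once every 256 generated
tokens rather than on every decode step, which is what makes this "normal" (non-quantized)
KV cache path cheap during generation.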