optimize starcoder normal kv cache (#10642)

parent 3a9ab8f1ae
commit 702e686901

2 changed files with 68 additions and 10 deletions
@@ -17,7 +17,10 @@
 import torch
 
-from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
+from .models.utils import (
+    init_fp8_kv_cache, append_fp8_kv_cache,
+    init_kv_cache, append_kv_cache
+)
 from typing import Optional, Dict, Tuple, Any
 from transformers.cache_utils import DynamicCache
 
@@ -63,3 +66,58 @@ class DynamicFp8Cache(DynamicCache):
             self.value_cache[layer_idx] = v_cache
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+class DynamicNormalCache(DynamicCache):
+    KV_ALLOC_BLOCK_LENGTH = 256
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+        new_layout=False,   # unused; kept to match DynamicFp8Cache's signature
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        batch_size, num_heads, seq_len, head_dim = key_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # transformers 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # transformers 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            # first call for this layer: allocate KV_ALLOC_BLOCK_LENGTH spare
+            # positions so later tokens can be appended in place
+            k_cache, v_cache = init_kv_cache(
+                batch_size, num_heads, head_dim,
+                0, key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+
+            self.key_cache.append(k_cache)
+            self.value_cache.append(v_cache)
+        else:
+            k_cache = self.key_cache[layer_idx]
+            v_cache = self.value_cache[layer_idx]
+
+            kv_seq_len = k_cache.size(2) + key_states.size(2)
+            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+                # spare capacity exhausted: reallocate with another block and copy
+                new_k_cache, new_v_cache = init_kv_cache(
+                    batch_size, num_heads, head_dim,
+                    k_cache.size(2), kv_seq_len + self.KV_ALLOC_BLOCK_LENGTH,
+                    key_states.dtype, key_states.device
+                )
+                new_k_cache[...] = k_cache[...]
+                new_v_cache[...] = v_cache[...]
+                k_cache = new_k_cache
+                v_cache = new_v_cache
+            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+            self.key_cache[layer_idx] = k_cache
+            self.value_cache[layer_idx] = v_cache
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
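DynamicNormalCache applies the same pre-allocation trick the fp8 cache already uses: each layer's cache is backed by a buffer with KV_ALLOC_BLOCK_LENGTH spare token positions, so the common single-token decode step writes in place instead of reallocating and copying the whole cache; only when the spare room runs out (detected by the stride check above) does it pay for one reallocation. The helpers init_kv_cache and append_kv_cache come from .models.utils and are not shown in this patch; the sketch below is a minimal standalone approximation of how such helpers can work under that buffer/view layout, not the library's actual implementation.

import torch

KV_ALLOC_BLOCK_LENGTH = 256

def init_kv_cache(batch_size, num_heads, head_dim,
                  current_len, max_len, dtype, device):
    # Allocate one contiguous (batch, heads, max_len, head_dim) buffer and
    # return views of the first `current_len` positions. The views keep the
    # buffer's strides, so stride(1) encodes the real capacity per head.
    k_buf = torch.empty(batch_size, num_heads, max_len, head_dim,
                        dtype=dtype, device=device)
    v_buf = torch.empty_like(k_buf)
    return k_buf[:, :, :current_len, :], v_buf[:, :, :current_len, :]

def append_kv_cache(k_cache, v_cache, key_states, value_states):
    # Widen the views over the spare capacity and write only the new tokens;
    # the existing tokens are never copied on this fast path.
    b, h, cur_len, d = k_cache.shape
    new_size = (b, h, cur_len + key_states.size(2), d)
    k_cache = k_cache.as_strided(new_size, k_cache.stride(),
                                 storage_offset=k_cache.storage_offset())
    v_cache = v_cache.as_strided(new_size, v_cache.stride(),
                                 storage_offset=v_cache.storage_offset())
    k_cache[:, :, cur_len:, :] = key_states
    v_cache[:, :, cur_len:, :] = value_states
    return k_cache, v_cache

# One decode step: append a single token without reallocating.
k = torch.randn(1, 8, 1, 64)
v = torch.randn(1, 8, 1, 64)
k_cache, v_cache = init_kv_cache(1, 8, 64, 0, 1 + KV_ALLOC_BLOCK_LENGTH,
                                 k.dtype, k.device)
k_cache, v_cache = append_kv_cache(k_cache, v_cache, k, v)
# The capacity test used by DynamicNormalCache.update: stride(1) equals
# max_len * head_dim, so this holds while spare positions remain.
assert k_cache.stride(1) >= (k_cache.size(2) + 1) * k_cache.size(3)

Amortized over a generation, this turns per-token cache growth from O(seq_len) copies into an occasional block-sized reallocation, which is the optimization the commit title refers to.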
@@ -44,7 +44,7 @@ from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     apply_rotary_pos_emb_no_cache_xpu
 )
-from ipex_llm.transformers.kv import DynamicFp8Cache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 
 from typing import Optional, Tuple, List
@@ -56,6 +56,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
 def should_use_fuse_rope(self, hidden_states, position_ids):
     use_fuse_rope = (
         hidden_states.device.type == "xpu" and
         hidden_states.numel() == hidden_states.size(-1) and
         not (self.training and hidden_states.requires_grad) and
         position_ids is not None
     )
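A note on the gate above: hidden_states has shape (batch, seq_len, hidden_size), so hidden_states.numel() == hidden_states.size(-1) holds exactly when batch * seq_len == 1, i.e. the single-token decode step. A quick check of that reading:

import torch

decode_step = torch.zeros(1, 1, 4096)   # (batch=1, seq_len=1, hidden_size)
prefill = torch.zeros(1, 128, 4096)     # (batch=1, seq_len=128, hidden_size)
assert decode_step.numel() == decode_step.size(-1)   # fused rope eligible
assert prefill.numel() != prefill.size(-1)           # falls back to normal rope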
@@ -130,12 +131,8 @@ def attention_forward(
                       "`past_key_value` cannot be None")
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
 
-    if use_quantize_kv:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None, new_layout=True)
-    else:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+    key_states, value_states = past_key_value.update(key_states, value_states,
+                                                     self.layer_idx, None, new_layout=True)
 
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
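Because DynamicNormalCache.update accepts (and ignores) new_layout, the fp8/normal branch at the call site collapses into a single update call; the cache object itself decides how the states are stored. A minimal illustration of that dispatch shape, using stand-in classes (not the library's):

import torch

class _Fp8LikeCache:
    def update(self, k, v, layer_idx, cache_kwargs=None, new_layout=False):
        # fp8 path: new_layout would select the optimized storage layout
        return k, v

class _NormalLikeCache:
    def update(self, k, v, layer_idx, cache_kwargs=None, new_layout=False):
        # normal path: new_layout is accepted purely for signature parity
        return k, v

k = v = torch.zeros(1, 8, 1, 64)
for cache in (_Fp8LikeCache(), _NormalLikeCache()):
    # one call site serves both cache types
    k, v = cache.update(k, v, 0, None, new_layout=True)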
@@ -188,9 +185,12 @@ def model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids):
-        if not isinstance(past_key_values, DynamicFp8Cache):
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
+    if use_cache:
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     return Starcoder2Model.forward(
         self=self,
         input_ids=input_ids,
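The reshaped model_forward decides once per forward pass whether the quantized kv cache applies, then coerces past_key_values to the matching cache class before delegating to Starcoder2Model.forward. The same logic, folded into a hypothetical helper for clarity (ensure_kv_cache is not part of the patch):

from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache

def ensure_kv_cache(past_key_values, use_cache: bool, use_quantize_kv: bool):
    # Coerce whatever cache came in (legacy tuples, None, or a DynamicCache
    # subclass) to the implementation matching the quantization decision.
    if not use_cache:
        return past_key_values
    if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        return DynamicFp8Cache.from_legacy_cache(past_key_values)
    if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
        return DynamicNormalCache.from_legacy_cache(past_key_values)
    return past_key_values

The isinstance guards matter: on every decode step after the first, the cache is already the right class, so the conversion (and its copy) runs only once per generation.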