optimize starcoder normal kv cache (#10642)
This commit is contained in:
parent 3a9ab8f1ae
commit 702e686901

2 changed files with 68 additions and 10 deletions
@@ -17,7 +17,10 @@
 
 import torch
 
-from .models.utils import init_fp8_kv_cache, append_fp8_kv_cache
+from .models.utils import (
+    init_fp8_kv_cache, append_fp8_kv_cache,
+    init_kv_cache, append_kv_cache
+)
 from typing import Optional, Dict, Tuple, Any
 from transformers.cache_utils import DynamicCache
 
@@ -63,3 +66,58 @@ class DynamicFp8Cache(DynamicCache):
             self.value_cache[layer_idx] = v_cache
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
+class DynamicNormalCache(DynamicCache):
+    KV_ALLOC_BLOCK_LENGTH = 256
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+        new_layout=False,   # useless, just keep same as DynamicFp8Cache
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        batch_size, num_heads, seq_len, head_dim = key_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            k_cache, v_cache = init_kv_cache(
+                batch_size, num_heads, head_dim,
+                0, key_states.size(2) + self.KV_ALLOC_BLOCK_LENGTH,
+                key_states.dtype, key_states.device
+            )
+            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+
+            self.key_cache.append(k_cache)
+            self.value_cache.append(v_cache)
+        else:
+            k_cache = self.key_cache[layer_idx]
+            v_cache = self.value_cache[layer_idx]
+
+            kv_seq_len = k_cache.size(2) + key_states.size(2)
+            if k_cache.stride(1) < kv_seq_len * k_cache.size(3):
+                new_k_cache, new_v_cache = init_kv_cache(
+                    batch_size, num_heads, head_dim,
+                    k_cache.size(2), kv_seq_len + self.KV_ALLOC_BLOCK_LENGTH,
+                    key_states.dtype, key_states.device
+                )
+                new_k_cache[...] = k_cache[...]
+                new_v_cache[...] = v_cache[...]
+                k_cache = new_k_cache
+                v_cache = new_v_cache
+            k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+            self.key_cache[layer_idx] = k_cache
+            self.value_cache[layer_idx] = v_cache
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
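The class above leans on two helpers from ipex_llm.transformers.models.utils. Their real implementations are not part of this diff; the sketch below is only a minimal illustration of the contract DynamicNormalCache appears to rely on: init_kv_cache reserves extra capacity along the sequence dimension (visible through stride(1)), and append_kv_cache writes new tokens into that reserved space without re-allocating.

import torch

def init_kv_cache(batch_size, num_heads, head_dim,
                  current_length, max_length, dtype, device):
    # Reserve storage for up to `max_length` tokens, but expose only
    # `current_length` of them. The full capacity stays visible through
    # stride(1), which is what the `k_cache.stride(1) < kv_seq_len * k_cache.size(3)`
    # check above inspects before deciding to re-allocate.
    k_full = torch.empty(batch_size, num_heads, max_length, head_dim,
                         dtype=dtype, device=device)
    v_full = torch.empty(batch_size, num_heads, max_length, head_dim,
                         dtype=dtype, device=device)
    return k_full.narrow(2, 0, current_length), v_full.narrow(2, 0, current_length)

def append_kv_cache(k_cache, v_cache, key_states, value_states):
    # Extend the views over the reserved buffer and copy the new tokens into
    # the tail; no allocation or full-cache copy happens while capacity lasts.
    old_len = k_cache.size(2)
    new_len = old_len + key_states.size(2)
    size = (k_cache.size(0), k_cache.size(1), new_len, k_cache.size(3))
    k_new = k_cache.as_strided(size, k_cache.stride())
    v_new = v_cache.as_strided(size, v_cache.stride())
    k_new[:, :, old_len:, :] = key_states
    v_new[:, :, old_len:, :] = value_states
    return k_new, v_new

Reserving KV_ALLOC_BLOCK_LENGTH = 256 extra slots on each (re-)allocation amortizes the cost: most decode steps only write a single token into memory that already exists, instead of concatenating and copying the whole cache on every step.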
@@ -44,7 +44,7 @@ from ipex_llm.transformers.models.utils import (
     use_quantize_kv_cache, restore_fp8_kv_cache,
     apply_rotary_pos_emb_no_cache_xpu
 )
-from ipex_llm.transformers.kv import DynamicFp8Cache
+from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
 from ipex_llm.utils.common.log4Error import invalidInputError
 
 from typing import Optional, Tuple, List
@@ -56,6 +56,7 @@ from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Model,
 def should_use_fuse_rope(self, hidden_states, position_ids):
     use_fuse_rope = (
         hidden_states.device.type == "xpu" and
+        hidden_states.numel() == hidden_states.size(-1) and
         not (self.training and hidden_states.requires_grad) and
         position_ids is not None
     )
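For reference, the added hidden_states.numel() == hidden_states.size(-1) condition holds only when every leading dimension is 1, i.e. a batch-of-one, single-token decode step. A small illustration (the hidden size below is arbitrary, not taken from the Starcoder2 config):

import torch

hidden_size = 3072                               # arbitrary, for illustration only
decode_step = torch.randn(1, 1, hidden_size)     # batch 1, one new token
prefill = torch.randn(1, 128, hidden_size)       # batch 1, 128-token prompt

print(decode_step.numel() == decode_step.size(-1))   # True  -> fused RoPE allowed
print(prefill.numel() == prefill.size(-1))           # False -> regular RoPE path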
@@ -130,12 +131,8 @@ def attention_forward(
                       "`past_key_value` cannot be None")
     use_quantize_kv = use_quantize_kv_cache(self.o_proj, hidden_states)
 
-    if use_quantize_kv:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None, new_layout=True)
-    else:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None)
+    key_states, value_states = past_key_value.update(key_states, value_states,
+                                                     self.layer_idx, None, new_layout=True)
 
     if use_quantize_kv and q_len == 1:
         import linear_q4_0
@@ -188,9 +185,12 @@ def model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids):
-        if not isinstance(past_key_values, DynamicFp8Cache):
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
+    if use_cache:
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     return Starcoder2Model.forward(
         self=self,
         input_ids=input_ids,
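A hypothetical smoke test for the new cache class, assuming an environment where ipex_llm containing this change is installed; the tensor shapes and CPU tensors are arbitrary and only meant to show the prefill/decode flow:

import torch
from ipex_llm.transformers.kv import DynamicNormalCache

cache = DynamicNormalCache()

# prefill: 16 prompt tokens -> reserves 16 + KV_ALLOC_BLOCK_LENGTH slots for layer 0
k = torch.randn(1, 4, 16, 64)   # (batch, kv_heads, seq_len, head_dim)
v = torch.randn(1, 4, 16, 64)
cache.update(k, v, layer_idx=0, new_layout=True)   # new_layout is accepted but ignored here

# decode: one token per step stays inside the reserved block,
# so no re-allocation happens until the slack is used up
k_step = torch.randn(1, 4, 1, 64)
v_step = torch.randn(1, 4, 1, 64)
key_cache, value_cache = cache.update(k_step, v_step, layer_idx=0, new_layout=True)
print(key_cache.shape)   # expected: torch.Size([1, 4, 17, 64])

The second update should extend the layer-0 cache to 17 tokens while staying inside the block reserved during prefill, which is the "normal kv cache" optimization the commit title refers to.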