optimize phi3 memory usage (#11867)
This commit is contained in:
		
							parent
							
								
									5b83493b1a
								
							
						
					
					
						commit
						d4ee0a89f3
					
				
					 2 changed files with 26 additions and 3 deletions
				
			
		| 
						 | 
				
			
			@ -121,6 +121,21 @@ class DynamicNormalCache(DynamicCache):
 | 
			
		|||
 | 
			
		||||
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_reserved(cls, layers: int,
 | 
			
		||||
                      bsz: int, n_head: int, length: int, head_dim: int,
 | 
			
		||||
                      dtype: torch.dtype, device: torch.device):
 | 
			
		||||
        past_key_values = cls()
 | 
			
		||||
        for _i in range(layers):
 | 
			
		||||
            k_cache, v_cache = init_kv_cache(
 | 
			
		||||
                bsz, n_head, head_dim,
 | 
			
		||||
                0, length + cls.KV_ALLOC_BLOCK_LENGTH,
 | 
			
		||||
                dtype, device
 | 
			
		||||
            )
 | 
			
		||||
            past_key_values.key_cache.append(k_cache)
 | 
			
		||||
            past_key_values.value_cache.append(v_cache)
 | 
			
		||||
        return past_key_values
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
 | 
			
		||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -254,9 +254,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
 | 
			
		|||
    ):
 | 
			
		||||
        # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
        input = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
 | 
			
		||||
        use_compress_kv = should_use_compresskv(input, input.shape[1])
 | 
			
		||||
        inputs = input_ids if input_ids is not None else inputs_embeds
 | 
			
		||||
        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
 | 
			
		||||
        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            if use_compress_kv and not isinstance(past_key_values,
 | 
			
		||||
                                                  DynamicCompressCache):
 | 
			
		||||
| 
						 | 
				
			
			@ -272,6 +272,14 @@ def phi3_model_forward_wrapper(origin_model_forward):
 | 
			
		|||
                                                                               DynamicCompressCache
 | 
			
		||||
                                                                               )):
 | 
			
		||||
                past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
 | 
			
		||||
                if past_key_values.get_seq_length() == 0:
 | 
			
		||||
                    n_layer = self.config.num_hidden_layers
 | 
			
		||||
                    n_head = self.config.num_attention_heads
 | 
			
		||||
                    head_dim = self.config.hidden_size // self.config.num_attention_heads
 | 
			
		||||
                    past_key_values = DynamicNormalCache.from_reserved(
 | 
			
		||||
                        n_layer, inputs.size(0), n_head, inputs.size(1), head_dim,
 | 
			
		||||
                        inputs.dtype, inputs.device
 | 
			
		||||
                    )
 | 
			
		||||
        return origin_model_forward(
 | 
			
		||||
            self=self,
 | 
			
		||||
            input_ids=input_ids,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue