MiniCPM-V support compresskv (#11779)
* fix check error
* fix other models
* remove print
parent 3998de14f0
commit 7cd6ec9723

7 changed files with 9 additions and 9 deletions
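Every hunk below makes the same one-line change: the length argument passed to should_use_compresskv switches from shape[-1] to shape[1]. For a 2D input_ids tensor of shape (batch, seq_len) the two indices agree, but models like MiniCPM-V call the language model with inputs_embeds, a 3D tensor of shape (batch, seq_len, hidden_size), where shape[-1] picks up the hidden size instead of the sequence length. A minimal sketch of the failure mode (the concrete sizes are illustrative, not taken from the patch):

import torch

# 2D token ids, (batch, seq_len): shape[-1] and shape[1] agree.
input_ids = torch.randint(0, 32000, (1, 512))
assert input_ids.shape[-1] == input_ids.shape[1] == 512

# 3D embeddings, (batch, seq_len, hidden_size): they diverge.
# With the old check, a 512-token prompt looked like a
# "4096-token" prompt, so the compress-kv gate misfired.
inputs_embeds = torch.randn(1, 512, 4096)
assert inputs_embeds.shape[-1] == 4096  # hidden size, wrong for a length check
assert inputs_embeds.shape[1] == 512    # the actual sequence length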
				
			
@@ -87,7 +87,7 @@ def chatglm2_model_forward(
                                  dtype=inputs_embeds.dtype, device=inputs_embeds.device)

     if use_cache:
-        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
+        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
@@ -51,7 +51,7 @@ def chatglm4_model_forward(

     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
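chatglm4 (like phi3 and qwen2 further down) first falls back from input_ids to inputs_embeds before running the check, which is exactly the path where shape[1] matters. For orientation, here is a hypothetical sketch of what a should_use_compresskv gate could look like; the signature matches the call sites in these hunks, but the environment variable and length threshold are assumptions, not the actual ipex-llm implementation:

import os
import torch

def should_use_compresskv(x: torch.Tensor, prompt_len: int) -> bool:
    # Hypothetical gate: compress-kv is opt-in and only worthwhile for
    # long prompts. x mirrors the call sites above; a real implementation
    # might also inspect its device or dtype.
    enabled = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", "0") == "1"
    return enabled and prompt_len >= 2048  # threshold is an assumed value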
				
			
@@ -128,7 +128,7 @@ def llama_model_forward_4_36(
                                   self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -168,7 +168,7 @@ def llama_model_forward_4_38(
                                   self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -209,7 +209,7 @@ def llama_model_forward_4_41(
                                   self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
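The three llama hunks (transformers 4.36, 4.38, and 4.41) and the minicpm/mistral hunks below all share the same cache-selection ladder: the quantized FP8 cache takes priority, and compress-kv is engaged only otherwise. Factored out, the pattern is roughly the sketch below; the cache classes and their from_legacy_cache constructors appear verbatim in the hunks, but the import path is an assumption:

# Assumed import path for the cache classes used in the hunks above.
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicCompressCache

def wrap_kv_cache(past_key_values, use_quantize_kv: bool, use_compress_kv: bool):
    # Quantized kv wins; compress-kv is skipped when both are requested,
    # matching the "if use quantize kv, compress kv will be ignored now"
    # comment in the patched forwards.
    if use_quantize_kv:
        if not isinstance(past_key_values, DynamicFp8Cache):
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    elif use_compress_kv:
        if not isinstance(past_key_values, DynamicCompressCache):
            past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
    return past_key_values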
				
			
@@ -628,7 +628,7 @@ def minicpm_model_forward(
                                   self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input, input.shape[-1]):
+        elif should_use_compresskv(input, input.shape[1]):
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)

@@ -211,7 +211,7 @@ def mistral_model_forward_4_36(
                                   self.config.num_attention_heads//self.config.num_key_value_heads):
             if not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif should_use_compresskv(input_ids, input_ids.shape[-1]):
+        elif should_use_compresskv(input_ids, input_ids.shape[1]):
             # if use quantize kv, compress kv will be ignored now
             if not isinstance(past_key_values, DynamicCompressCache):
                 past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -258,7 +258,7 @@ def phi3_model_forward_wrapper(origin_model_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         input = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
-        use_compress_kv = should_use_compresskv(input, input.shape[-1])
+        use_compress_kv = should_use_compresskv(input, input.shape[1])
         if use_cache:
             if use_compress_kv and not isinstance(past_key_values,
                                                   DynamicCompressCache):
@@ -118,7 +118,7 @@ def qwen2_model_forward(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
-    use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])

     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
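One model-specific detail: qwen2 (like the llama and mistral variants) also passes the GQA group size, num_attention_heads // num_key_value_heads, into the quantize-kv check. With illustrative config values of 32 attention heads and 8 kv heads (example numbers, not from the patch), the integer division gives 4 query heads sharing each kv head:

num_attention_heads = 32  # illustrative values, not from the patch
num_key_value_heads = 8
group_size = num_attention_heads // num_key_value_heads
assert group_size == 4  # four query heads share each kv head under GQA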