fix llama3.1/3.2 quantize kv check (#12302)
parent 416c19165c
commit 72605c7016

2 changed files with 9 additions and 3 deletions
@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
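The new third argument is the GQA group size, i.e. query heads per KV head (32 // 8 = 4 for Llama 3.1 8B), presumably so the quantize-kv heuristic can account for grouped-query models caching only num_key_value_heads heads. A minimal sketch of a group-aware check, assuming a use_quantize_kv_cache(linear, x, kv_group) signature as in the call above; the body and threshold are illustrative, not ipex-llm's actual implementation:

import torch

def use_quantize_kv_cache_sketch(linear: torch.nn.Linear, x: torch.Tensor,
                                 kv_group: int = 1) -> bool:
    # Illustrative heuristic: quantize the KV cache only on XPU and only
    # when the per-KV-head slice of the hidden size is large enough to pay
    # off. With GQA the cache holds 1/kv_group of the query heads, which is
    # why the group size has to enter the check at all.
    return x.device.type == "xpu" and linear.out_features // kv_group >= 1024

# Example wiring (Llama 3.1 8B shapes: hidden 4096, intermediate 14336,
# 32 attention heads, 8 KV heads):
down_proj = torch.nn.Linear(14336, 4096)
kv_group = 32 // 8  # config.num_attention_heads // config.num_key_value_heads
x = torch.zeros(1, 1, 4096)
print(use_quantize_kv_cache_sketch(down_proj, x, kv_group))  # False: x is on CPU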
@@ -114,7 +117,7 @@ def llama_model_forward(
 
     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IEPX_LLM OPT end
 
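The second hunk broadens the fused-rope gate: checkpoints whose rotary embedding reports rope_type "default" (no "llama3" long-context scaling in config.rope_scaling) now take the fused path too, not only the "llama3" scaled variant. A hedged stand-in for just the rope_type half of that condition (should_use_fuse_rope and its device checks are ipex-llm internals and are not modeled here):

# Stand-in for the rope_type part of the gate; the list mirrors the diff.
FUSED_ROPE_TYPES = ["default", "llama3"]  # previously just ["llama3"]

def rope_type_supports_fused(rope_type: str) -> bool:
    # "default": rotary embedding without llama3 long-context scaling;
    # "llama3": the scaled rope used by Llama 3.1/3.2 checkpoints.
    return rope_type in FUSED_ROPE_TYPES

assert rope_type_supports_fused("default")
assert rope_type_supports_fused("llama3")
assert not rope_type_supports_fused("linear")   # other scalings stay excluded
assert not rope_type_supports_fused("dynamic")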
@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
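The mllama hunk applies the same fix to the text model of Llama 3.2 Vision, so both forwards keep sharing the promotion step visible in the context lines: wrap a legacy cache in DynamicFp8Cache exactly once. A small sketch of that idempotent-wrap pattern with a stub cache class (DynamicFp8Cache and from_legacy_cache are the names in the diff; the stub body is assumed, not the real implementation):

from typing import Optional, Tuple

class DynamicFp8Cache:
    # Stub for ipex_llm's FP8 KV cache; the real class presumably
    # quantizes K/V tensors as they are appended (not shown here).
    def __init__(self, layers=None):
        self.layers = list(layers) if layers else []

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple] = None):
        # Mirrors the Hugging Face Cache API: accept legacy per-layer
        # (key, value) tuples, or None at the start of generation.
        return cls(past_key_values)

def maybe_promote(past_key_values, use_cache: bool, use_quantize_kv: bool):
    # The isinstance guard from both hunks makes the wrap idempotent
    # across decoding steps.
    if use_cache and use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    return past_key_values

cache = maybe_promote(None, use_cache=True, use_quantize_kv=True)
assert isinstance(cache, DynamicFp8Cache)
assert maybe_promote(cache, True, True) is cache  # not wrapped twice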