diff --git a/python/llm/src/ipex_llm/transformers/models/llama32.py b/python/llm/src/ipex_llm/transformers/models/llama32.py
index f105669f..cad33744 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama32.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
@@ -114,7 +117,7 @@ def llama_model_forward(
 
     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IEPX_LLM OPT end
diff --git a/python/llm/src/ipex_llm/transformers/models/mllama.py b/python/llm/src/ipex_llm/transformers/models/mllama.py
index 4a05346e..9086fd2a 100644
--- a/python/llm/src/ipex_llm/transformers/models/mllama.py
+++ b/python/llm/src/ipex_llm/transformers/models/mllama.py
@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
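
For context on the `use_quantize_kv_cache` change: the extra argument forwarded in both files is the grouped-query-attention (GQA) ratio, i.e. how many query heads share each KV head. Below is a minimal sketch of that computation, assuming Llama-3.2-3B-style head counts (24 query heads, 8 KV heads); the config values are illustrative only and not taken from this diff.

```python
from transformers import LlamaConfig

# Hypothetical example: head counts below mirror a Llama-3.2-3B-style config
# and are illustrative, not read from the patched files.
config = LlamaConfig(num_attention_heads=24, num_key_value_heads=8)

# GQA group size: number of query heads served by each shared KV head.
# This is the value the patched calls now pass to use_quantize_kv_cache.
kv_group_size = config.num_attention_heads // config.num_key_value_heads
assert kv_group_size == 3  # 24 // 8
```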