fix llama3.1/3.2 quantize kv check (#12302)
parent 416c19165c
commit 72605c7016

2 changed files with 9 additions and 3 deletions
@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
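Note: the new third argument is the grouped-query-attention group size (e.g. Llama 3.1 8B runs 32 attention heads over 8 key/value heads, so num_attention_heads // num_key_value_heads == 4), letting the check account for the kv cache already being group-size times smaller under GQA. A minimal sketch of a group-size-aware check, assuming this shape for use_quantize_kv_cache; the env-var name, threshold, and heuristic below are assumptions, not ipex-llm's actual implementation:

import os
import torch

def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
                          kv_group: int = 1) -> bool:
    # Explicit user override (env-var name is an assumption).
    env = os.environ.get("IPEX_LLM_QUANTIZE_KV_CACHE")
    if env is not None:
        return env == "1"
    # Heuristic sketch: an fp8 kv cache pays off on XPU for small effective
    # batch sizes; under GQA the cache is already kv_group times smaller,
    # so fold the group size into the threshold. The real check also
    # inspects `linear` (e.g. its quantization dtype), not modeled here.
    return x.device.type == "xpu" and x.size(0) * kv_group <= 4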
@@ -114,7 +117,7 @@ def llama_model_forward(
 
     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IEPX_LLM OPT end
 
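Note: the gate previously required rope_type == "llama3" (the scaled rope of Llama 3.1/3.2), which skipped the fused path for models whose config reports the plain "default" rope. A sketch of how the widened gate selects a rope path, assuming the eager branch mirrors upstream transformers; the helper name and packaging are illustrative:

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

def rope_inputs(rotary_emb: LlamaRotaryEmbedding,
                hidden_states: torch.Tensor,
                position_ids: torch.Tensor,
                fuse_ok: bool):
    if fuse_ok and rotary_emb.rope_type in ["default", "llama3"]:
        # Fused path: hand the raw inverse frequencies to the fused
        # attention kernel, which applies the rotation itself.
        return rotary_emb.inv_freq
    # Eager path: precompute (cos, sin) as upstream transformers does.
    return rotary_emb(hidden_states, position_ids)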
@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
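Note: in both forwards, once quantization is enabled any legacy tuple cache is wrapped into DynamicFp8Cache so subsequent attention updates go through the fp8 store. A sketch of that wrapping, modeled on transformers' DynamicCache.from_legacy_cache; the class here is a hypothetical stand-in for ipex-llm's DynamicFp8Cache:

import torch
from transformers.cache_utils import DynamicCache

class DynamicFp8CacheSketch(DynamicCache):
    # Same interface as DynamicCache; the real class stores k/v in fp8.
    @classmethod
    def from_legacy_cache(cls, past_key_values=None):
        cache = cls()
        if past_key_values is not None:
            for layer_idx, (k, v) in enumerate(past_key_values):
                # The real implementation quantizes k/v to fp8 here.
                cache.update(k, v, layer_idx)
        return cache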