fix llama3.1/3.2 quantize kv check (#12302)
parent 416c19165c
commit 72605c7016
2 changed files with 9 additions and 3 deletions
@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
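This hunk threads the grouped-query-attention (GQA) ratio, query heads per KV head, into the quantize-KV decision, so Llama 3.1/3.2 checkpoints (which use GQA) are judged on their actual KV-cache width rather than the full hidden size. Below is a minimal sketch of what such a check can look like; the "_sketch" function name, the QUANTIZE_KV_CACHE override variable, and the 1024 threshold are illustrative assumptions, not ipex-llm's actual implementation.

import os
import torch

def use_quantize_kv_cache_sketch(linear: torch.nn.Linear,
                                 x: torch.Tensor,
                                 kv_group: int = 1) -> bool:
    # Hypothetical sketch only: the real ipex-llm helper may use different
    # names, thresholds, and environment variables.
    override = os.environ.get("QUANTIZE_KV_CACHE")  # assumed override knob
    if override is not None:
        return override.lower() in ("1", "true")
    # With grouped-query attention, each KV head serves `kv_group` query
    # heads, so the effective KV width is roughly the hidden size divided
    # by kv_group; quantizing only pays off above some size threshold.
    return x.device.type == "xpu" and linear.out_features // kv_group >= 1024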
@@ -114,7 +117,7 @@ def llama_model_forward(
     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IEPX_LLM OPT end
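The second hunk widens the fused-rope guard: previously only the "llama3" rope-scaling variant qualified, which skipped the fused kernel for checkpoints whose rotary embedding reports rope_type == "default". A small sketch of the widened predicate follows; should_use_fuse_rope here is a stand-in with assumed behavior, not the actual ipex-llm helper.

import torch

def should_use_fuse_rope(hidden_states: torch.Tensor,
                         position_ids: torch.Tensor,
                         training: bool) -> bool:
    # Stand-in for the ipex-llm helper of the same name; assumed behavior:
    # the fused rope kernel is only used on XPU, outside of training.
    return (hidden_states.device.type == "xpu"
            and not training
            and position_ids is not None)

def can_use_fused_rope(hidden_states, position_ids, rotary_emb) -> bool:
    # After the fix, both the plain "default" rope and the "llama3"
    # scaled rope take the fused path; any other rope_type falls back
    # to the stock Hugging Face rotary embedding.
    return (should_use_fuse_rope(hidden_states, position_ids, False)
            and getattr(rotary_emb, "rope_type", None) in ("default", "llama3"))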
@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
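The mllama hunk mirrors the llama one: the quantize-KV check for the Llama 3.2 Vision text model now also receives the query-to-KV head ratio. As a quick worked example of the value being passed (head counts here are illustrative; read them from your own checkpoint's config in practice):

from transformers import LlamaConfig

# Illustrative head counts (Llama 3.1 8B ships 32 query heads and 8 KV
# heads); substitute the config of the checkpoint you actually load.
config = LlamaConfig(num_attention_heads=32, num_key_value_heads=8)

# This ratio is the extra argument the fixed calls now forward to
# use_quantize_kv_cache: how many query heads share each KV head.
kv_group = config.num_attention_heads // config.num_key_value_heads
print(kv_group)  # -> 4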