fix llama3.1/3.2 quantize kv check (#12302)

Yishuo Wang, 2024-10-31 11:55:07 +08:00, committed by GitHub
parent 416c19165c
commit 72605c7016
2 changed files with 9 additions and 3 deletions


@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
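
For reference, the new third argument is the grouped-query attention (GQA) ratio: how many query heads share each key/value head. A minimal sketch of computing that value from a model config, assuming the standard Hugging Face LlamaConfig field names and the meta-llama/Llama-3.1-8B-Instruct checkpoint (both are illustrative assumptions, not part of this commit):

    from transformers import AutoConfig

    # Illustrative sketch: load a Llama 3.1 config and derive the GQA group
    # size that the patched code now passes to use_quantize_kv_cache.
    config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    gqa_group_size = config.num_attention_heads // config.num_key_value_heads
    # Llama 3.1 8B uses 32 attention heads and 8 KV heads, so this prints 4.
    print(gqa_group_size)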
@@ -114,7 +117,7 @@ def llama_model_forward(
     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IEPX_LLM OPT end
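
For context, Hugging Face rotary embeddings typically fall back to rope_type "default" when a checkpoint has no rope_scaling entry, while Llama 3.1/3.2 checkpoints declare rope_type "llama3"; widening the check keeps fused rope enabled in both cases. A hypothetical helper sketching that selection logic (written for illustration, not copied from transformers or this repository):

    # Hypothetical helper: mirrors how rope_type is typically derived from a
    # Hugging Face model config.
    def get_rope_type(config) -> str:
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            return "default"  # e.g. checkpoints without RoPE scaling
        # newer configs use the "rope_type" key, older ones used "type"
        return rope_scaling.get("rope_type", rope_scaling.get("type", "default"))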


@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)