update quantize kv cache condition (#12681)

commit 7234c9b27b (parent 5d8081afbc)
Author: Yishuo Wang
Date: 2025-01-09 15:23:04 +08:00 (committed by GitHub)
20 changed files with 75 additions and 37 deletions


@@ -73,7 +73,9 @@ def baichuan_model_7b_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+                                                self.config.num_attention_heads,
+                                                self.config.num_attention_heads)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
             if use_quantize_kv:
@@ -246,8 +248,6 @@ def baichuan_attention_forward_7b(
     key_states = key_states.to(hidden_states.dtype)
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
-
     # [CompressKV]
     if use_compresskv:
         enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -258,6 +258,8 @@ def baichuan_attention_forward_7b(
             query_states, attention_mask, 1,
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
     else:
+        use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+                                                self.num_heads, self.num_heads)
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, device
@@ -308,7 +310,8 @@ def baichuan_attention_forward_13b(
         kv_seq_len += past_key_value[0].shape[2]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device
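
The Baichuan hunks above show the pattern this commit repeats in every model_forward: the caller resolves its own attention and KV head counts and forwards them to use_quantize_kv_cache before deciding whether to switch to the FP8 cache. A condensed, hypothetical sketch of that pattern; the helper name choose_kv_cache, the getattr fallback, and the import paths are assumptions for illustration, not part of the diff:

# Condensed sketch of the caller-side pattern applied throughout this commit.
# Import paths are assumed from the identifiers visible in the diff.
from ipex_llm.transformers.models.utils import use_quantize_kv_cache  # assumed path
from ipex_llm.transformers.kv import DynamicFp8Cache                  # assumed path


def choose_kv_cache(self, inputs, past_key_values):
    num_heads = self.config.num_attention_heads
    # MHA models such as Baichuan (above) have no separate KV head count,
    # so they pass num_heads twice; GQA/MQA models pass their real KV head count.
    num_kv_heads = getattr(self.config, "num_key_value_heads", num_heads)
    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                            num_heads, num_kv_heads)
    # Only wrap the cache in the FP8 variant when the new condition says so.
    if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
        past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
    return past_key_values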


@@ -63,8 +63,13 @@ def chatglm2_model_forward(
     if use_cache:
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
+        n_heads = self.config.num_attention_heads
+        if self.config.multi_query_attention:
+            n_kv_heads = self.config.multi_query_group_num
+        else:
+            n_kv_heads = n_heads
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
-                                                input_ids)
+                                                input_ids, n_heads, n_kv_heads)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
             if use_quantize_kv:
@@ -257,8 +262,6 @@ def chatglm2_attention_forward(
         key_states[..., :rot_dim] = k_rot[...]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
     # [CompressKV]
     if use_compresskv:
         from transformers.configuration_utils import PretrainedConfig
@@ -272,6 +275,8 @@ def chatglm2_attention_forward(
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
         )
     else:
+        use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+                                                n_head, n_kv_head)
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, hidden_states.device


@@ -55,8 +55,13 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
-                                                inputs)
+        n_heads = self.config.num_attention_heads
+        if self.config.multi_query_attention:
+            n_kv_heads = self.config.multi_query_group_num
+        else:
+            n_kv_heads = n_heads
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
+                                                n_heads, n_kv_heads)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
             if use_quantize_kv:
@@ -211,8 +216,6 @@ def chatglm4_attention_forward(
         key_states[..., :rot_dim] = k_rot[...]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
     # [CompressKV]
     if use_compresskv:
         from transformers.configuration_utils import PretrainedConfig
@@ -226,6 +229,8 @@ def chatglm4_attention_forward(
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
         )
     else:
+        use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+                                                n_head, n_kv_head)
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, hidden_states.device


@@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
         key_states[..., :rot_dim] = k_rot[...]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
+    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, hidden_states.device


@@ -147,7 +147,7 @@ def glm_model_forward_wrapper(origin_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         use_cache = use_cache or inputs.device.type == 'xpu'
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
-                                                self.config.num_attention_heads //
+                                                self.config.num_attention_heads,
                                                 self.config.num_key_value_heads)
         if use_cache:


@@ -87,7 +87,8 @@ def internlm_attention_forward(
     )
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, hidden_states.device
@@ -171,7 +172,8 @@ def internlm2_attention_forward(
     )
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
+                                            self.num_heads, self.num_key_value_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, hidden_states.device
@@ -346,7 +348,8 @@ def internlm_xcomposser2_attention_forward(
         query_states, key_states, cos, sin, position_ids, "internlm")
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
+                                            self.num_heads, self.num_key_value_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device


@@ -72,7 +72,7 @@ def llama_model_forward(
     use_cache = True if inputs.device.type == "xpu" else use_cache
     use_quantize_kv = use_quantize_kv_cache(
         self.layers[0].mlp.down_proj, inputs,
-        self.config.num_attention_heads // self.config.num_key_value_heads
+        self.config.num_attention_heads, self.config.num_key_value_heads
     )
     use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
         isinstance(past_key_values, DynamicCompressCache)
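
For GQA models such as Llama, the old call collapsed the head configuration into a single group ratio, which discarded the absolute KV head count that the new condition checks. A small illustration with hypothetical config values (32 query heads, 8 KV heads); the variable names below are illustrative only:

# Hypothetical GQA configuration, e.g. a Llama-style model.
num_attention_heads = 32
num_key_value_heads = 8

# Old call: only the group ratio survived, so 32 heads / 8 KV heads and
# 8 heads / 2 KV heads both looked like the same value "4".
kv_group = num_attention_heads // num_key_value_heads   # 4

# New call: both counts are passed, so use_quantize_kv_cache can check
# num_kv_heads >= 4 and the ratio num_heads // num_kv_heads independently.
args = (num_attention_heads, num_key_value_heads)        # (32, 8)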


@@ -159,7 +159,7 @@ def minicpm_model_forward_wrapper(origin_forward):
         # IPEX-LLM OPT: kv cache and quantize kv cache
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
-                                                self.config.num_attention_heads //
+                                                self.config.num_attention_heads,
                                                 self.config.num_key_value_heads)
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
             isinstance(past_key_values, DynamicCompressCache)


@@ -66,7 +66,9 @@ def minicpm3_model_forward_wrapper(origin_forward):
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         use_cache = True if inputs.device.type == "xpu" else use_cache
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+                                                num_heads, num_kv_heads)
         if use_cache:
             if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -71,7 +71,7 @@ def mistral_model_forward(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     use_cache = use_cache or inputs.device.type == 'xpu'
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
-                                            self.config.num_attention_heads //
+                                            self.config.num_attention_heads,
                                             self.config.num_key_value_heads)
     use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
         isinstance(past_key_values, DynamicCompressCache)


@@ -113,7 +113,7 @@ def mllama_text_model_forward(
     use_cache = True if inputs.device.type == "xpu" else use_cache
     use_quantize_kv = use_quantize_kv_cache(
         self.layers[0].mlp.down_proj, inputs,
-        self.config.num_attention_heads // self.config.num_key_value_heads
+        self.config.num_attention_heads, self.config.num_key_value_heads
     )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):


@@ -249,7 +249,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
         # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+                                                num_heads, num_kv_heads)
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
             isinstance(past_key_values, DynamicCompressCache)
         if use_cache:
@@ -305,7 +307,9 @@ def phi3v_model_forward_wrapper(origin_model_forward):
     ):
         # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
+        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids,
+                                                num_heads, num_kv_heads)
         if use_cache:
             if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -107,7 +107,8 @@ def qwen_attention_forward(
         query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device
@@ -205,7 +206,8 @@ def qwen_attention_forward_registered(
         query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device


@@ -113,10 +113,10 @@ def qwen2_model_forward(
     # ipex-llm changes start
     # IPEX-LLM OPT: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
-                                  self.config.num_attention_heads//self.config.num_key_value_heads)
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, num_heads, num_kv_heads)
     )
     use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
         isinstance(past_key_values, DynamicCompressCache)
@@ -305,10 +305,11 @@ def qwen2_model_forward_4_42(
     # ipex-llm changes start
     # IPEX-LLM OPT: kv cache and quantize kv cache
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
-                                  self.config.num_attention_heads//self.config.num_key_value_heads)
+                                  num_heads, num_kv_heads)
     )
     use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
         isinstance(past_key_values, DynamicCompressCache)


@@ -73,8 +73,10 @@ def qwen2moe_model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    input = input_ids if input_ids is not None else inputs_embeds
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input)
+    inputs = input_ids if input_ids is not None else inputs_embeds
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, inputs,
+                                            num_heads, num_kv_heads)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -88,7 +88,9 @@ def qwen2_vl_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+                                            num_heads, num_kv_heads)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -69,8 +69,10 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
     use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
-                       and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
+                       and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
+                                                 num_heads, num_kv_heads))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -132,7 +132,9 @@ def model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids,
+                                            num_heads, num_kv_heads)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)


@@ -74,7 +74,8 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v
 
 
-def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
+                          num_heads: int, num_kv_heads: int) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@@ -90,8 +91,11 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
     else:
         device_name = get_xpu_device_name(x.device)
         return (
-            device_name in ["mtl", "lnl", "arl"] and kv_group == 1
-            or device_name in ["arc", "bmg"] and x.size(0) > 1
+            num_kv_heads >= 4
+            and (
+                device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
+                or device_name in ["arc", "bmg"] and x.size(0) > 1
+            )
         )
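
The rewritten condition above is the heart of the commit: the FP8 KV cache is only considered when the model has at least 4 KV heads, and on MTL/LNL/ARL devices it additionally requires a GQA ratio of at most 4, while Arc/BMG still also requires a batch size above 1. A minimal standalone sketch of that decision, with the device lookup stubbed out (the function name should_quantize_kv and the parameters device_name and batch_size are stand-ins for get_xpu_device_name(x.device) and x.size(0), not ipex-llm API):

def should_quantize_kv(device_name: str, batch_size: int,
                       num_heads: int, num_kv_heads: int) -> bool:
    # Mirror of the new gating logic in use_quantize_kv_cache (sketch only):
    # models with fewer than 4 KV heads never use the quantized (FP8) KV cache.
    return (
        num_kv_heads >= 4
        and (
            device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
            or device_name in ["arc", "bmg"] and batch_size > 1
        )
    )


# Examples with hypothetical head counts:
assert should_quantize_kv("mtl", 1, 32, 8)        # GQA ratio 4, 8 KV heads -> quantize
assert not should_quantize_kv("mtl", 1, 32, 2)    # only 2 KV heads -> keep regular cache
assert should_quantize_kv("arc", 4, 32, 8)        # batched decode on Arc -> quantize
assert not should_quantize_kv("arc", 1, 32, 8)    # single batch on Arc -> keep regular cache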


@@ -158,7 +158,8 @@ def yuan_attention_forward(
                                                       "yuan")
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         None if past_key_value is None else (past_key_value[0], past_key_value[1]),
         key_states, value_states,