update quantize kv cache condition (#12681)
parent 5d8081afbc
commit 7234c9b27b
20 changed files with 75 additions and 37 deletions
@@ -73,7 +73,9 @@ def baichuan_model_7b_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
+                                                self.config.num_attention_heads,
+                                                self.config.num_attention_heads)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
             if use_quantize_kv:

@@ -246,8 +248,6 @@ def baichuan_attention_forward_7b(
     key_states = key_states.to(hidden_states.dtype)
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
-
     # [CompressKV]
     if use_compresskv:
         enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,

@@ -258,6 +258,8 @@ def baichuan_attention_forward_7b(
             query_states, attention_mask, 1,
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
     else:
+        use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+                                                self.num_heads, self.num_heads)
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, device

@@ -308,7 +310,8 @@ def baichuan_attention_forward_13b(
         kv_seq_len += past_key_value[0].shape[2]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device

@@ -63,8 +63,13 @@ def chatglm2_model_forward(
 
     if use_cache:
         use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
+        n_heads = self.config.num_attention_heads
+        if self.config.multi_query_attention:
+            n_kv_heads = self.config.multi_query_group_num
+        else:
+            n_kv_heads = n_heads
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
-                                                input_ids)
+                                                input_ids, n_heads, n_kv_heads)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
             if use_quantize_kv:

@@ -257,8 +262,6 @@ def chatglm2_attention_forward(
     key_states[..., :rot_dim] = k_rot[...]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
     # [CompressKV]
     if use_compresskv:
         from transformers.configuration_utils import PretrainedConfig

@@ -272,6 +275,8 @@ def chatglm2_attention_forward(
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
         )
     else:
+        use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+                                                n_head, n_kv_head)
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, hidden_states.device

@@ -55,8 +55,13 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
-        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
-                                                inputs)
+        n_heads = self.config.num_attention_heads
+        if self.config.multi_query_attention:
+            n_kv_heads = self.config.multi_query_group_num
+        else:
+            n_kv_heads = n_heads
+        use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
+                                                n_heads, n_kv_heads)
         if use_compress_kv and not isinstance(past_key_values,
                                               DynamicCompressCache):
             if use_quantize_kv:

@@ -211,8 +216,6 @@ def chatglm4_attention_forward(
     key_states[..., :rot_dim] = k_rot[...]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
-
     # [CompressKV]
     if use_compresskv:
         from transformers.configuration_utils import PretrainedConfig

@@ -226,6 +229,8 @@ def chatglm4_attention_forward(
             self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
         )
     else:
+        use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
+                                                n_head, n_kv_head)
         key_states, value_states = update_past_key_value(
             past_key_value, key_states, value_states,
             kv_seq_len, use_quantize_kv, hidden_states.device

@@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
     key_states[..., :rot_dim] = k_rot[...]
 
     # IPEX-LLM OPT: kv cache and quantize kv
-    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
+    use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, hidden_states.device

@@ -147,7 +147,7 @@ def glm_model_forward_wrapper(origin_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         use_cache = use_cache or inputs.device.type == 'xpu'
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
-                                                self.config.num_attention_heads //
+                                                self.config.num_attention_heads,
                                                 self.config.num_key_value_heads)
 
         if use_cache:

@@ -87,7 +87,8 @@ def internlm_attention_forward(
     )
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, hidden_states.device

@@ -171,7 +172,8 @@ def internlm2_attention_forward(
     )
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
+                                            self.num_heads, self.num_key_value_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, hidden_states.device

@@ -346,7 +348,8 @@ def internlm_xcomposser2_attention_forward(
         query_states, key_states, cos, sin, position_ids, "internlm")
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
+                                            self.num_heads, self.num_key_value_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device

@@ -72,7 +72,7 @@ def llama_model_forward(
     use_cache = True if inputs.device.type == "xpu" else use_cache
     use_quantize_kv = use_quantize_kv_cache(
         self.layers[0].mlp.down_proj, inputs,
-        self.config.num_attention_heads // self.config.num_key_value_heads
+        self.config.num_attention_heads, self.config.num_key_value_heads
     )
     use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
         isinstance(past_key_values, DynamicCompressCache)

@@ -159,7 +159,7 @@ def minicpm_model_forward_wrapper(origin_forward):
         # IPEX-LLM OPT: kv cache and quantize kv cache
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
-                                                self.config.num_attention_heads //
+                                                self.config.num_attention_heads,
                                                 self.config.num_key_value_heads)
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
             isinstance(past_key_values, DynamicCompressCache)

@@ -66,7 +66,9 @@ def minicpm3_model_forward_wrapper(origin_forward):
         inputs = input_ids if input_ids is not None else inputs_embeds
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         use_cache = True if inputs.device.type == "xpu" else use_cache
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+                                                num_heads, num_kv_heads)
         if use_cache:
             if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

@@ -71,7 +71,7 @@ def mistral_model_forward(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     use_cache = use_cache or inputs.device.type == 'xpu'
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
-                                            self.config.num_attention_heads //
+                                            self.config.num_attention_heads,
                                             self.config.num_key_value_heads)
     use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
         isinstance(past_key_values, DynamicCompressCache)

@@ -113,7 +113,7 @@ def mllama_text_model_forward(
     use_cache = True if inputs.device.type == "xpu" else use_cache
     use_quantize_kv = use_quantize_kv_cache(
         self.layers[0].mlp.down_proj, inputs,
-        self.config.num_attention_heads // self.config.num_key_value_heads
+        self.config.num_attention_heads, self.config.num_key_value_heads
     )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):

@@ -249,7 +249,9 @@ def phi3_model_forward_wrapper(origin_model_forward):
         # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+                                                num_heads, num_kv_heads)
         use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
             isinstance(past_key_values, DynamicCompressCache)
         if use_cache:

@@ -305,7 +307,9 @@ def phi3v_model_forward_wrapper(origin_model_forward):
     ):
         # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
+        num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+        use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids,
+                                                num_heads, num_kv_heads)
        if use_cache:
            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
                past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

@@ -107,7 +107,8 @@ def qwen_attention_forward(
         query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device

@@ -205,7 +206,8 @@ def qwen_attention_forward_registered(
         query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         past_key_value, key_states, value_states,
         kv_seq_len, use_quantize_kv, device

@@ -113,10 +113,10 @@ def qwen2_model_forward(
     # ipex-llm changes start
     # IPEX-LLM OPT: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
-                                  self.config.num_attention_heads//self.config.num_key_value_heads)
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, num_heads, num_kv_heads)
     )
     use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
         isinstance(past_key_values, DynamicCompressCache)

@@ -305,10 +305,11 @@ def qwen2_model_forward_4_42(
 
     # ipex-llm changes start
     # IPEX-LLM OPT: kv cache and quantize kv cache
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
     use_quantize_kv = (
         self.config.hidden_size != 3584  # disable quantize kv in specific model
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
-                                  self.config.num_attention_heads//self.config.num_key_value_heads)
+                                  num_heads, num_kv_heads)
     )
     use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
         isinstance(past_key_values, DynamicCompressCache)

@@ -73,8 +73,10 @@ def qwen2moe_model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    input = input_ids if input_ids is not None else inputs_embeds
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input)
+    inputs = input_ids if input_ids is not None else inputs_embeds
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, inputs,
+                                            num_heads, num_kv_heads)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

@@ -88,7 +88,9 @@ def qwen2_vl_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
    use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
+                                            num_heads, num_kv_heads)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

@@ -69,8 +69,10 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
     use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
-                       and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
+                       and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
+                                                 num_heads, num_kv_heads))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

@@ -132,7 +132,9 @@ def model_forward(
     return_dict: Optional[bool] = None,
 ):
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
+    num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids,
+                                            num_heads, num_kv_heads)
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)

@@ -74,7 +74,8 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
     return new_cache_k, new_cache_v
 
 
-def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
+def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
+                          num_heads: int, num_kv_heads: int) -> bool:
     if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
         warnings.warn(
             "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "

@@ -90,8 +91,11 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
     else:
         device_name = get_xpu_device_name(x.device)
         return (
-            device_name in ["mtl", "lnl", "arl"] and kv_group == 1
-            or device_name in ["arc", "bmg"] and x.size(0) > 1
+            num_kv_heads >= 4
+            and (
+                device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
+                or device_name in ["arc", "bmg"] and x.size(0) > 1
+            )
         )
 
 

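The hunk above is the core of the commit: use_quantize_kv_cache drops its kv_group argument and instead decides from explicit head counts, which is why every call site in the hunks above now passes num_heads and num_kv_heads. The standalone sketch below restates the new gate so it can be read and tried outside the model code; should_quantize_kv_cache and the explicit device_name parameter are illustrative stand-ins only (the real helper takes the linear module, reads the device from x, and also honors the BIGDL_QUANTIZE_KV_CACHE environment handling shown earlier in this file).

import torch


def should_quantize_kv_cache(x: torch.Tensor, device_name: str,
                             num_heads: int, num_kv_heads: int) -> bool:
    # Mirror of the new condition: quantized (FP8) KV cache is only considered
    # for models with at least 4 KV heads, and then only on integrated GPUs
    # ("mtl", "lnl", "arl") when the GQA group size is small (<= 4), or on
    # discrete GPUs ("arc", "bmg") when the batch dimension is larger than 1.
    return (
        num_kv_heads >= 4
        and (
            device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
            or device_name in ["arc", "bmg"] and x.size(0) > 1
        )
    )


if __name__ == "__main__":
    ids = torch.zeros(1, 128, dtype=torch.long)          # batch of 1
    print(should_quantize_kv_cache(ids, "mtl", 32, 32))  # MHA, group size 1 -> True
    print(should_quantize_kv_cache(ids, "mtl", 32, 4))   # GQA, group size 8 -> False
    print(should_quantize_kv_cache(ids, "arc", 32, 8))   # dGPU, batch of 1  -> False

The num_kv_heads >= 4 floor is the behavioral change: the old kv_group ratio alone could not distinguish a model with very few KV heads, which is what the explicit head counts passed by the callers now make possible.
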
@@ -158,7 +158,8 @@ def yuan_attention_forward(
                                                     "yuan")
 
     # IPEX-LLM OPT: kv cache and quantzie kv cache
-    use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states)
+    use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states,
+                                            self.num_heads, self.num_heads)
     key_states, value_states = update_past_key_value(
         None if past_key_value is None else (past_key_value[0], past_key_value[1]),
         key_states, value_states,