Phi3 support compresskv (#11733)
* phi3 support compresskv
* fix phi3 mtl error
* fix conflict with quant kv
* fix abnormal on mtl
* fix style
* use sliding window size to compress kv
* support sliding window
* fix style
* fix style
* temp: partial support quant kv
* support quant kv with compress kv, todo: model check
* temp
* fix style
* fix style
* remove prepare
* address comment
* default -> 1.8k
parent d8808cc2e3
commit dd46c141bd

3 changed files with 146 additions and 82 deletions
@@ -154,6 +154,11 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
     bsz, num_heads, q_len, head_dim = query_states.shape
     if q_len <= attn_config.max_capacity_prompt:
         return key_states, value_states
+
+    sliding_window_size = getattr(attn_config, "sliding_window", None)
+    if sliding_window_size is not None and sliding_window_size <= 2500:
+        return key_states[:, :, -sliding_window_size:, :], \
+            value_states[:, :, -sliding_window_size:, :]
     else:
         key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
         attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
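For reference, the sliding-window shortcut added above simply keeps the most recent positions along the sequence axis, since tokens that have fallen out of the window can never be attended to again. A minimal, self-contained sketch with toy shapes (the window value is illustrative; e.g. Phi-3-mini's config uses sliding_window=2047):

import torch

# (batch, num_kv_heads, seq_len, head_dim)
key_states = torch.randn(1, 32, 3000, 96)
value_states = torch.randn(1, 32, 3000, 96)
sliding_window_size = 2047  # illustrative; the real code reads it from attn_config

# Keep only the last `sliding_window_size` positions of the KV cache.
k_kept = key_states[:, :, -sliding_window_size:, :]
v_kept = value_states[:, :, -sliding_window_size:, :]
assert k_kept.size(2) == sliding_window_size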
@@ -166,7 +171,8 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
         mask = mask.to(attn_weights.device)
         attention_mask = mask[None, None, :, :]
-        attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask
+        attn_weights[:, :, -attn_config.window_size:,
+                     -attn_config.window_size:] += attention_mask
         attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                              dtype=torch.float32).to(query_states.dtype)
@@ -174,7 +180,8 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
                                         :-attn_config.window_size].sum(dim=-2)
         if attn_config.pooling == 'avgpool':
             if num_key_value_groups > 1:
-                attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
+                attn_cache = F.avg_pool2d(attn_weights_sum,
+                                          kernel_size=(num_key_value_groups,
                                                        attn_config.kernel_size),
                                           padding=(0, attn_config.kernel_size//2),
                                           stride=(num_key_value_groups, 1))
@@ -196,10 +203,10 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_mask,
         indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
                                   dim=-1).indices
         indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
-        k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
-                                                                                index=indices)
-        v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
-                                                                                  index=indices)
+        k_past_compress = key_states[:, :, :-attn_config.window_size, :]\
+            .gather(dim=2, index=indices)
+        v_past_compress = value_states[:, :, :-attn_config.window_size, :]\
+            .gather(dim=2, index=indices)
         k_cur = key_states[:, :, -attn_config.window_size:, :]
         v_cur = value_states[:, :, -attn_config.window_size:, :]
         key_states = torch.cat([k_past_compress, k_cur], dim=2)
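Taken together, these hunks implement SnapKV-style prompt compression: the last window_size queries score every prefix position, optional pooling smooths the per-head scores, and topk/gather keep only the highest-scoring prefix KV entries plus the observation window itself. A condensed, runnable sketch of the selection path (toy sizes; causal masking inside the window omitted; num_key_value_groups assumed to be 1 so no repeat_kv or pooling is needed):

import math
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 4, 512, 64
window_size, max_capacity_prompt = 32, 128

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# Score each prefix position by the attention it receives from the last window.
attn = torch.matmul(query_states[..., -window_size:, :],
                    key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(query_states.dtype)
scores = attn[..., :-window_size].sum(dim=-2)  # (bsz, num_heads, q_len - window_size)

# Keep the best-scoring prefix positions plus the observation window itself.
indices = scores.topk(max_capacity_prompt - window_size, dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
k_past = key_states[:, :, :-window_size, :].gather(dim=2, index=indices)
v_past = value_states[:, :, :-window_size, :].gather(dim=2, index=indices)
k_compressed = torch.cat([k_past, key_states[:, :, -window_size:, :]], dim=2)
v_compressed = torch.cat([v_past, value_states[:, :, -window_size:, :]], dim=2)
assert k_compressed.size(2) == max_capacity_prompt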
@@ -208,9 +215,11 @@ class DynamicCompressCache(DynamicCache):
 class DynamicCompressCache(DynamicCache):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, quant_kv=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.real_kv_len = 0
+        self.quant_kv = quant_kv
+        self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache
 
     def update_seen_tokens(self, layer_idx, q_len):
         if layer_idx == 0:
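Binding append_kv_func once in __init__ means the decode path never re-branches on quant_kv; update() just calls whatever was selected. A generic sketch of that dispatch pattern with stand-in append functions (the real append_kv_cache/append_fp8_kv_cache live in ipex_llm and have richer semantics):

from typing import Callable, Tuple
import torch

def append_plain(k_cache, v_cache, k_new, v_new) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for append_kv_cache: concatenate along the sequence dimension.
    return torch.cat([k_cache, k_new], dim=2), torch.cat([v_cache, v_new], dim=2)

def append_quantized(k_cache, v_cache, k_new, v_new) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for append_fp8_kv_cache: a real implementation would quantize
    # k_new/v_new to FP8 before appending; here we only match the cache dtype.
    return (torch.cat([k_cache, k_new.to(k_cache.dtype)], dim=2),
            torch.cat([v_cache, v_new.to(v_cache.dtype)], dim=2))

class CacheSketch:
    def __init__(self, quant_kv: bool = False):
        self.quant_kv = quant_kv
        # Select the append routine once instead of branching on every update.
        self.append_kv_func: Callable = append_quantized if quant_kv else append_plain

cache = CacheSketch(quant_kv=True)
k = v = torch.randn(1, 32, 4, 96)
k2, v2 = cache.append_kv_func(k, v, k, v)
assert k2.size(2) == 8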
@@ -260,33 +269,46 @@ class DynamicCompressCache(DynamicCache):
             self.key_cache.append(key_states_compress)
             self.value_cache.append(value_states_compress)
 
-            k_cache_compressed, v_cache_compressed = init_kv_cache(
-                bsz, num_heads, head_dim,
-                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                key_states.dtype, key_states.device
-            )
-            k_cache_compressed, v_cache_compressed = append_kv_cache(
+            if not self.quant_kv:
+                k_cache_compressed, v_cache_compressed = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    key_states.dtype, key_states.device
+                )
+            else:
+                k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
+                    bsz, num_heads, seq_len, head_dim,
+                    device=key_states.device,
+                )
+            k_cache_compressed, v_cache_compressed = self.append_kv_func(
                 k_cache_compressed, v_cache_compressed,
                 key_states_compress, value_states_compress)
             self.key_cache[layer_idx] = k_cache_compressed
             self.value_cache[layer_idx] = v_cache_compressed
 
             if key_states.stride(2) != head_dim:
-                k_cache, v_cache = init_kv_cache(
-                    bsz, num_heads, head_dim,
-                    0, key_states.size(2),
-                    key_states.dtype, key_states.device
-                )
-                k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+                if not self.quant_kv:
+                    k_cache, v_cache = init_kv_cache(
+                        bsz, num_heads, head_dim,
+                        0, key_states.size(2),
+                        key_states.dtype, key_states.device
+                    )
+                else:
+                    k_cache, v_cache = init_fp8_kv_cache(
+                        bsz, num_heads, 0, head_dim, key_states.device
+                    )
+                k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
+                                                       key_states, value_states)
                 return k_cache, v_cache
             else:
                 return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
-            if not enough_kv_room:
+            if not enough_kv_room and not self.quant_kv:
                 # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
+                new_c_k, new_c_v = extend_kv_cache(
+                    bsz,
                     num_heads,  # Support GQA
                     head_dim,
                     cache_k.size(2),
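One detail worth calling out here is the key_states.stride(2) != head_dim guard: it detects K/V tensors that are still non-contiguous views into the fused QKV projection, which this hunk copies into a freshly allocated, packed cache before returning. A small demonstration of what that check catches (toy shapes):

import torch

bsz, seq_len, num_heads, head_dim = 1, 10, 32, 96

# A (bsz, heads, seq, dim) view produced by transposing a fused buffer:
k_view = torch.randn(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
print(k_view.stride(2), head_dim)  # e.g. 3072 vs 96 -> not packed

# Copying makes the sequence dimension head_dim-contiguous again:
k_packed = k_view.contiguous()
assert k_packed.stride(2) == head_dim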
@@ -299,7 +321,7 @@ class DynamicCompressCache(DynamicCache):
                 cache_k = new_c_k
                 cache_v = new_c_v
 
-            key_states, value_states = append_kv_cache(cache_k,
+            key_states, value_states = self.append_kv_func(cache_k,
                                                        cache_v,
                                                        key_states,
                                                        value_states)
@@ -316,3 +338,14 @@ class DynamicCompressCache(DynamicCache):
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.real_kv_len
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+                          quantize_kv: Optional[bool] = False) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls(quantize_kv)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
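The new classmethod mirrors DynamicCache.from_legacy_cache but threads the quantize_kv flag through the constructor. A self-contained sketch of the conversion pattern using a minimal stand-in class (TinyCache and its plain-concat update are illustrative, not the library's implementation):

from typing import List, Tuple
import torch

class TinyCache:
    """Minimal stand-in for the library's DynamicCache, for illustration only."""
    def __init__(self, quant_kv: bool = False):
        self.quant_kv = quant_kv
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor,
               layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    @classmethod
    def from_legacy_cache(cls, past_key_values=None, quantize_kv: bool = False) -> "TinyCache":
        # Same shape as the new classmethod above: build, then replay layer by layer.
        cache = cls(quantize_kv)
        if past_key_values is not None:
            for layer_idx, (k, v) in enumerate(past_key_values):
                cache.update(k, v, layer_idx)
        return cache

legacy = tuple((torch.randn(1, 32, 16, 96), torch.randn(1, 32, 16, 96)) for _ in range(2))
cache = TinyCache.from_legacy_cache(legacy, quantize_kv=False)
assert len(cache.key_cache) == 2 and cache.key_cache[0].size(2) == 16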
@@ -31,6 +31,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import math
 import torch
 import warnings
@@ -40,11 +41,13 @@ from ipex_llm.transformers.models.utils import should_use_fuse_rope, rotate_half
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
 
 from typing import Optional, Tuple, List
 from transformers.models.phi.modeling_phi import repeat_kv
 from transformers.cache_utils import Cache
 
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
+
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -94,6 +97,9 @@ def attention_forward(
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -127,12 +133,26 @@ def attention_forward(
                                                         cos, sin, position_ids)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+        # [CompressKV]
+        if use_compresskv:
+            # print(attention_mask.shape)
+            context_len = key_states.size(2)
+            attention_mask = attention_mask[:, :, :, -context_len:]
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if isinstance(past_key_value,
+                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
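Because compression shrinks the KV length below the prompt length, the additive causal mask that upstream code built for the full sequence no longer matches key_states.size(2); the new branch slices the mask down to the kept tail before handing it to the SDP kernel. A toy illustration:

import torch

q_len, full_kv_len = 8, 3000
attention_mask = torch.zeros(1, 1, q_len, full_kv_len)  # additive mask for the full prompt

context_len = 1800  # stands in for key_states.size(2) after compression
attention_mask = attention_mask[:, :, :, -context_len:]
assert attention_mask.size(-1) == context_len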
@@ -148,7 +168,8 @@ def attention_forward(
             # attn_output = xe_addons.sdp_causal(query_states, key_states,
             #                                    value_states, attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if isinstance(past_key_value,
+                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
@@ -235,10 +256,20 @@ def phi3_model_forward_wrapper(origin_model_forward):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         input = input_ids if input_ids is not None else inputs_embeds
         use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
+        use_compress_kv = should_use_compresskv(input, input.shape[-1])
         if use_cache:
-            if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            if use_compress_kv and not isinstance(past_key_values,
+                                                  DynamicCompressCache):
+                past_key_values = DynamicCompressCache.\
+                    from_legacy_cache(past_key_values,
+                                      quantize_kv=use_quantize_kv)
+            if use_quantize_kv and not isinstance(past_key_values,
+                                                  (DynamicFp8Cache, DynamicCompressCache)):
                 past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-            if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                              (DynamicNormalCache,
+                                                                               DynamicCompressCache
+                                                                               )):
                 past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
             return origin_model_forward(
                 self=self,
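The ordering of these isinstance checks encodes a precedence: CompressKV wins when enabled (carrying quantization inside via quantize_kv), and the FP8/normal branches are taught not to replace an existing DynamicCompressCache. A small sketch of the resulting decision table (names abbreviated; not the library API):

def pick_cache(use_quantize_kv: bool, use_compress_kv: bool) -> str:
    # Mirrors the branch order in phi3_model_forward_wrapper above.
    if use_compress_kv:
        return f"DynamicCompressCache(quantize_kv={use_quantize_kv})"
    if use_quantize_kv:
        return "DynamicFp8Cache"
    return "DynamicNormalCache"

for quant in (False, True):
    for compress in (False, True):
        print(f"quant={quant} compress={compress} -> {pick_cache(quant, compress)}")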
@@ -490,7 +490,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     if use_compress_kv is None:
         return (
             get_xpu_device_type(x) == "mtl"
-            and prompt_len >= 2500
+            and prompt_len >= 1800
             and prompt_len <= 4500
         )
     else:
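This threshold change is the commit message's "default -> 1.8k": on MTL (Meteor Lake) iGPUs, CompressKV now kicks in for prompts of roughly 1.8k-4.5k tokens instead of 2.5k-4.5k. A sketch of the gate, assuming the surrounding code reads a user override into use_compress_kv from an environment variable (the variable name below is an assumption, not confirmed by this diff):

import os

def should_use_compresskv_sketch(prompt_len: int, device_type: str) -> bool:
    # Assumed env override; the diff only shows the `use_compress_kv is None` fallback.
    override = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE")
    if override is not None:
        return override == "1"
    return device_type == "mtl" and 1800 <= prompt_len <= 4500

assert should_use_compresskv_sketch(2000, "mtl")
assert not should_use_compresskv_sketch(1000, "mtl")
assert not should_use_compresskv_sketch(2000, "arc")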