parent 3ee6dec0f8
commit 6693e8ab04
3 changed files with 97 additions and 4 deletions
@@ -22,7 +22,8 @@ import math
 from .models.utils import (
     init_fp8_kv_cache, append_fp8_kv_cache,
-    init_kv_cache, append_kv_cache, extend_kv_cache
+    init_kv_cache, append_kv_cache, extend_kv_cache,
+    init_unbalanced_fp8_kv_cache, append_unbalanced_fp8_kv_cache,
 )
 from typing import Optional, Dict, Tuple, Any, List
 from transformers.cache_utils import DynamicCache
@@ -151,6 +152,55 @@ class DynamicNormalCache(DynamicCache):
         return past_key_values


+class DynamicUnbalancedFp8Cache(DynamicCache):
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+        # ignore num_hidden_layers to fix transformers >= 4.45
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states
+
+        batch_size, num_heads, seq_len, k_head_dim = key_states.shape
+        _, _, _, v_head_dim = value_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            k_cache, v_cache = init_unbalanced_fp8_kv_cache(
+                batch_size, num_heads, seq_len, k_head_dim, v_head_dim,
+                device=key_states.device,
+            )
+            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
+                                                              key_states, value_states)
+
+            self.key_cache.append(k_cache)
+            self.value_cache.append(v_cache)
+        else:
+            k_cache = self.key_cache[layer_idx]
+            v_cache = self.value_cache[layer_idx]
+            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
+                                                              key_states, value_states)
+            self.key_cache[layer_idx] = k_cache
+            self.value_cache[layer_idx] = v_cache
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
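For reference, a minimal sketch of how this new cache class would be driven, analogous to the existing DynamicNormalCache. The shapes, the "xpu" device, and the float16 dtype below are illustrative assumptions only; the fp8 quantization inside append_unbalanced_fp8_kv_cache still requires an Intel GPU where xe_addons is available.

import torch

# Illustrative shapes: the key head dim differs from the value head dim,
# which is exactly the "unbalanced" case this cache is meant to handle.
batch_size, num_heads, seq_len = 1, 8, 16
k_head_dim, v_head_dim = 192, 128
device = "xpu"  # assumption: Intel GPU with xe_addons installed

cache = DynamicUnbalancedFp8Cache()
key_states = torch.randn(batch_size, num_heads, seq_len, k_head_dim,
                         dtype=torch.float16, device=device)
value_states = torch.randn(batch_size, num_heads, seq_len, v_head_dim,
                           dtype=torch.float16, device=device)

# Each decoder layer calls update() with its own layer index; the new tokens are
# quantized into the fp8 (uint8) cache and the full history is returned.
k_all, v_all = cache.update(key_states, value_states, layer_idx=0)
# k_all: (1, 8, 16, 192) uint8, v_all: (1, 8, 16, 128) uint8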
@@ -273,11 +273,11 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         else:
             attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
     elif seq_length != kv_length and seq_length <= 32:
-        # todo: add scale support
+        # todo: add further scale support
         if key.dtype == torch.uint8:
-            attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+            attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
         else:
-            attn_output = xe_addons.sdp(query, key, value, mask)
+            attn_output = xe_addons.sdp(query, key, value, mask, scale)
     else:
         if key.dtype == torch.uint8:
             attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
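The functional change in this hunk is that the explicit scale argument is now forwarded to xe_addons.sdp and xe_addons.sdp_fp8 on the short-sequence path as well; previously that path did not pass scale, so models with a non-default scaling factor presumably fell back to the kernels' built-in default. As a rough plain-PyTorch sketch of what the argument means (a reference computation, not the xe_addons kernel itself, and assuming an additive attention mask):

import math
import torch

def reference_sdp(query, key, value, mask=None, scale=None):
    # scale defaults to 1/sqrt(head_dim), the standard attention scaling
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale
    if mask is not None:
        attn_weights = attn_weights + mask  # additive mask assumption
    attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)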
@@ -138,6 +138,49 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value):
     return new_k_cache, new_v_cache


+def init_unbalanced_fp8_kv_cache(batch_size, num_heads, current_length,
+                                 k_head_dim, v_head_dim, device):
+    # for the case where the k head dim differs from the v head dim
+    max_length = current_length + FP8_KV_ALLOC_LENGTH
+
+    k_cache_storage = torch.empty(batch_size, num_heads, max_length, k_head_dim,
+                                  dtype=torch.uint8, device=device)
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, k_head_dim),
+                                         k_cache_storage.stride(), storage_offset=0)
+
+    v_cache_storage = torch.empty(batch_size, num_heads, max_length, v_head_dim,
+                                  dtype=torch.uint8, device=device)
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, v_head_dim),
+                                         v_cache_storage.stride(), storage_offset=0)
+    return k_cache, v_cache
+
+
+def append_unbalanced_fp8_kv_cache(k_cache, v_cache, key, value):
+    batch_size, num_heads, cur_length, k_head_dim = k_cache.shape
+    _, _, _, v_head_dim = v_cache.shape
+    new_length = cur_length + key.size(2)
+    new_k_size = (batch_size, num_heads, new_length, k_head_dim)
+    new_v_size = (batch_size, num_heads, new_length, v_head_dim)
+
+    if k_cache.stride(1) < new_length * k_cache.size(3):
+        new_k_cache, new_v_cache = init_unbalanced_fp8_kv_cache(batch_size, num_heads, new_length,
+                                                                k_head_dim, v_head_dim, key.device)
+        new_k_cache = new_k_cache.as_strided(new_k_size, new_k_cache.stride(), storage_offset=0)
+        new_v_cache = new_v_cache.as_strided(new_v_size, new_v_cache.stride(), storage_offset=0)
+        new_k_cache[:, :, :cur_length, :] = k_cache
+        new_v_cache[:, :, :cur_length, :] = v_cache
+    else:
+        new_k_cache = k_cache.as_strided(new_k_size, k_cache.stride(), storage_offset=0)
+        new_v_cache = v_cache.as_strided(new_v_size, v_cache.stride(), storage_offset=0)
+
+    import xe_addons
+    xe_addons.quantize_key_value(key, value,
+                                 new_k_cache[:, :, cur_length:new_length, :],
+                                 new_v_cache[:, :, cur_length:new_length, :])
+
+    return new_k_cache, new_v_cache
+
+
 def restore_fp8_kv_cache(k_cache, v_cache, dtype):
     key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
     value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
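The two helpers above reuse the preallocation trick of the existing fp8 cache: reserve FP8_KV_ALLOC_LENGTH extra rows up front, expose a zero-length as_strided view over that storage, and widen the view in place on each append; only when the reserved capacity is exhausted does append_unbalanced_fp8_kv_cache reallocate and copy. A minimal CPU-only sketch of that mechanism (float16 instead of fp8, no xe_addons; the names and sizes are illustrative assumptions):

import torch

ALLOC_LENGTH = 256                      # stands in for FP8_KV_ALLOC_LENGTH
bsz, heads, head_dim = 1, 8, 192

# Allocate capacity once, but expose an empty (length-0) view over it.
storage = torch.empty(bsz, heads, ALLOC_LENGTH, head_dim, dtype=torch.float16)
cache = storage.as_strided((bsz, heads, 0, head_dim), storage.stride(), storage_offset=0)

new_key = torch.randn(bsz, heads, 16, head_dim, dtype=torch.float16)
new_len = cache.size(2) + new_key.size(2)

if cache.stride(1) >= new_len * cache.size(3):
    # Still fits in the underlying buffer: widen the view and copy only the new rows.
    cache = cache.as_strided((bsz, heads, new_len, head_dim), cache.stride(), storage_offset=0)
    cache[:, :, -new_key.size(2):, :] = new_key
else:
    # Out of reserved space: this is where the real helper reallocates and copies.
    pass

print(cache.shape)  # torch.Size([1, 8, 16, 192]); no reallocation was needed

The stride(1) comparison works because stride(1) of the cache view is the number of elements reserved per head (max_length * head_dim), so checking it against new_length * head_dim directly encodes whether the appended tokens still fit in the original allocation.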