parent
3ee6dec0f8
commit
6693e8ab04
3 changed files with 97 additions and 4 deletions
|
|
@ -22,7 +22,8 @@ import math
|
|||
|
||||
from .models.utils import (
|
||||
init_fp8_kv_cache, append_fp8_kv_cache,
|
||||
init_kv_cache, append_kv_cache, extend_kv_cache
|
||||
init_kv_cache, append_kv_cache, extend_kv_cache,
|
||||
init_unbalanced_fp8_kv_cache, append_unbalanced_fp8_kv_cache,
|
||||
)
|
||||
from typing import Optional, Dict, Tuple, Any, List
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
|
@ -151,6 +152,55 @@ class DynamicNormalCache(DynamicCache):
|
|||
return past_key_values
|
||||
|
||||
|
||||
class DynamicUnbalancedFp8Cache(DynamicCache):
|
||||
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
|
||||
# ignore num_hidden_layers to fix transformers >= 4.45
|
||||
super().__init__()
|
||||
|
||||
def update(
|
||||
self,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
layer_idx: int,
|
||||
cache_kwargs: Optional[Dict[str, Any]]=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# fix converting empty DynamicCache in transformers >= 4.45
|
||||
if key_states == []:
|
||||
return key_states, value_states
|
||||
|
||||
batch_size, num_heads, seq_len, k_head_dim = key_states.shape
|
||||
_, _, _, v_head_dim = value_states.shape
|
||||
|
||||
if layer_idx == 0:
|
||||
if hasattr(self, "_seen_tokens"):
|
||||
# 4.39 uses `_seen_tokens`
|
||||
self._seen_tokens += seq_len
|
||||
else:
|
||||
# 4.37 uses `seen_tokens`
|
||||
self.seen_tokens += seq_len
|
||||
|
||||
# Update the cache
|
||||
if len(self.key_cache) <= layer_idx:
|
||||
k_cache, v_cache = init_unbalanced_fp8_kv_cache(
|
||||
batch_size, num_heads, seq_len, k_head_dim, v_head_dim,
|
||||
device=key_states.device,
|
||||
)
|
||||
k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
|
||||
key_states, value_states)
|
||||
|
||||
self.key_cache.append(k_cache)
|
||||
self.value_cache.append(v_cache)
|
||||
else:
|
||||
k_cache = self.key_cache[layer_idx]
|
||||
v_cache = self.value_cache[layer_idx]
|
||||
k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
|
||||
key_states, value_states)
|
||||
self.key_cache[layer_idx] = k_cache
|
||||
self.value_cache[layer_idx] = v_cache
|
||||
|
||||
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -273,11 +273,11 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
|
|||
else:
|
||||
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
|
||||
elif seq_length != kv_length and seq_length <= 32:
|
||||
# todo: add scale support
|
||||
# todo: add further scale support
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
|
||||
else:
|
||||
attn_output = xe_addons.sdp(query, key, value, mask)
|
||||
attn_output = xe_addons.sdp(query, key, value, mask, scale)
|
||||
else:
|
||||
if key.dtype == torch.uint8:
|
||||
attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
|
||||
|
|
|
|||
|
|
@ -138,6 +138,49 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value):
|
|||
return new_k_cache, new_v_cache
|
||||
|
||||
|
||||
def init_unbalanced_fp8_kv_cache(batch_size, num_heads, current_length,
|
||||
k_head_dim, v_head_dim, device):
|
||||
# for case which k head dim is different from v head dim
|
||||
max_length = current_length + FP8_KV_ALLOC_LENGTH
|
||||
|
||||
k_cache_storage = torch.empty(batch_size, num_heads, max_length, k_head_dim,
|
||||
dtype=torch.uint8, device=device)
|
||||
k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, k_head_dim),
|
||||
k_cache_storage.stride(), storage_offset=0)
|
||||
|
||||
v_cache_storage = torch.empty(batch_size, num_heads, max_length, v_head_dim,
|
||||
dtype=torch.uint8, device=device)
|
||||
v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, v_head_dim),
|
||||
v_cache_storage.stride(), storage_offset=0)
|
||||
return k_cache, v_cache
|
||||
|
||||
|
||||
def append_unbalanced_fp8_kv_cache(k_cache, v_cache, key, value):
|
||||
batch_size, num_heads, cur_length, k_head_dim = k_cache.shape
|
||||
_, _, _, v_head_dim = v_cache.shape
|
||||
new_length = cur_length + key.size(2)
|
||||
new_k_size = (batch_size, num_heads, new_length, k_head_dim)
|
||||
new_v_size = (batch_size, num_heads, new_length, v_head_dim)
|
||||
|
||||
if k_cache.stride(1) < new_length * k_cache.size(3):
|
||||
new_k_cache, new_v_cache = init_unbalanced_fp8_kv_cache(batch_size, num_heads, new_length,
|
||||
k_head_dim, v_head_dim, key.device)
|
||||
new_k_cache = new_k_cache.as_strided(new_k_size, new_k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = new_v_cache.as_strided(new_v_size, new_v_cache.stride(), storage_offset=0)
|
||||
new_k_cache[:, :, :cur_length, :] = k_cache
|
||||
new_v_cache[:, :, :cur_length, :] = v_cache
|
||||
else:
|
||||
new_k_cache = k_cache.as_strided(new_k_size, k_cache.stride(), storage_offset=0)
|
||||
new_v_cache = v_cache.as_strided(new_v_size, v_cache.stride(), storage_offset=0)
|
||||
|
||||
import xe_addons
|
||||
xe_addons.quantize_key_value(key, value,
|
||||
new_k_cache[:, :, cur_length:new_length, :],
|
||||
new_v_cache[:, :, cur_length:new_length, :])
|
||||
|
||||
return new_k_cache, new_v_cache
|
||||
|
||||
|
||||
def restore_fp8_kv_cache(k_cache, v_cache, dtype):
|
||||
key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
|
||||
value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
|
||||
|
|
|
|||
Loading…
Reference in a new issue