Deepseek kv / sdp support (#13068)

* update kv

* fix

* fix style
Ruonan Wang 2025-04-11 11:26:15 +08:00 committed by GitHub
parent 3ee6dec0f8
commit 6693e8ab04
3 changed files with 97 additions and 4 deletions
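
Background (hedged): DeepSeek-style attention projects keys and values to different head dimensions, so the existing FP8 KV cache, which assumes one shared head dim for both, cannot hold them. The diff below adds an "unbalanced" FP8 cache variant with separate k/v buffers and also threads the caller's softmax scale into the short-sequence SDP paths. A minimal sketch of the head-dim mismatch, with illustrative (assumed) config values that are not taken from this diff:

    # Illustrative only: DeepSeek-style (MLA) attention splits the query/key
    # head dim into a non-rotary part and a rotary part, while values keep
    # their own head dim. All numbers below are assumptions for illustration.
    qk_nope_head_dim = 128   # non-rotary part of the query/key head dim
    qk_rope_head_dim = 64    # rotary part of the query/key head dim
    v_head_dim = 128         # value head dim

    k_head_dim = qk_nope_head_dim + qk_rope_head_dim   # 192
    # keys and values no longer share a head dim, so a single "balanced"
    # cache tensor cannot store both; the unbalanced cache added below keeps
    # separate k/v buffers, each with its own last dimension.
    assert k_head_dim != v_head_dim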


@@ -22,7 +22,8 @@ import math
 from .models.utils import (
     init_fp8_kv_cache, append_fp8_kv_cache,
-    init_kv_cache, append_kv_cache, extend_kv_cache
+    init_kv_cache, append_kv_cache, extend_kv_cache,
+    init_unbalanced_fp8_kv_cache, append_unbalanced_fp8_kv_cache,
 )
 from typing import Optional, Dict, Tuple, Any, List
 from transformers.cache_utils import DynamicCache
@@ -151,6 +152,55 @@ class DynamicNormalCache(DynamicCache):
         return past_key_values
 
 
+class DynamicUnbalancedFp8Cache(DynamicCache):
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+        # ignore num_hidden_layers to fix transformers >= 4.45
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states
+
+        batch_size, num_heads, seq_len, k_head_dim = key_states.shape
+        _, _, _, v_head_dim = value_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            k_cache, v_cache = init_unbalanced_fp8_kv_cache(
+                batch_size, num_heads, seq_len, k_head_dim, v_head_dim,
+                device=key_states.device,
+            )
+            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
+                                                              key_states, value_states)
+
+            self.key_cache.append(k_cache)
+            self.value_cache.append(v_cache)
+        else:
+            k_cache = self.key_cache[layer_idx]
+            v_cache = self.value_cache[layer_idx]
+            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
+                                                              key_states, value_states)
+            self.key_cache[layer_idx] = k_cache
+            self.value_cache[layer_idx] = v_cache
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """


@@ -273,11 +273,11 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         else:
             attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
     elif seq_length != kv_length and seq_length <= 32:
-        # todo: add scale support
+        # todo: add further scale support
         if key.dtype == torch.uint8:
-            attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+            attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
         else:
-            attn_output = xe_addons.sdp(query, key, value, mask)
+            attn_output = xe_addons.sdp(query, key, value, mask, scale)
     else:
         if key.dtype == torch.uint8:
             attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
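
This hunk forwards the caller-supplied scale into xe_addons.sdp_fp8 / xe_addons.sdp instead of letting those kernels assume a default. For reference, a hedged sketch of what an explicit softmax scale means, written against PyTorch's stock SDPA (the scale keyword exists in recent PyTorch releases); the shapes and the scale value are illustrative assumptions, not the xe_addons implementation:

    import math
    import torch
    import torch.nn.functional as F

    # query/key share head dim 192 while values use 128, mirroring the
    # unbalanced KV layout above; all numbers here are illustrative.
    q = torch.randn(1, 8, 4, 192)
    k = torch.randn(1, 8, 32, 192)
    v = torch.randn(1, 8, 32, 128)

    # With scale=None, SDPA divides attention scores by sqrt(q.size(-1)).
    out_default = F.scaled_dot_product_attention(q, k, v)

    # Passing an explicit scale overrides that default, which is what the
    # changed sdp / sdp_fp8 calls now receive from the model code.
    model_scale = 1.0 / math.sqrt(q.size(-1))       # model-specific in practice
    out_explicit = F.scaled_dot_product_attention(q, k, v, scale=model_scale)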


@@ -138,6 +138,49 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value):
     return new_k_cache, new_v_cache
 
 
+def init_unbalanced_fp8_kv_cache(batch_size, num_heads, current_length,
+                                 k_head_dim, v_head_dim, device):
+    # for case which k head dim is different from v head dim
+    max_length = current_length + FP8_KV_ALLOC_LENGTH
+    k_cache_storage = torch.empty(batch_size, num_heads, max_length, k_head_dim,
+                                  dtype=torch.uint8, device=device)
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, k_head_dim),
+                                         k_cache_storage.stride(), storage_offset=0)
+
+    v_cache_storage = torch.empty(batch_size, num_heads, max_length, v_head_dim,
+                                  dtype=torch.uint8, device=device)
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, v_head_dim),
+                                         v_cache_storage.stride(), storage_offset=0)
+    return k_cache, v_cache
+
+
+def append_unbalanced_fp8_kv_cache(k_cache, v_cache, key, value):
+    batch_size, num_heads, cur_length, k_head_dim = k_cache.shape
+    _, _, _, v_head_dim = v_cache.shape
+    new_length = cur_length + key.size(2)
+    new_k_size = (batch_size, num_heads, new_length, k_head_dim)
+    new_v_size = (batch_size, num_heads, new_length, v_head_dim)
+
+    if k_cache.stride(1) < new_length * k_cache.size(3):
+        new_k_cache, new_v_cache = init_unbalanced_fp8_kv_cache(batch_size, num_heads, new_length,
+                                                                k_head_dim, v_head_dim, key.device)
+        new_k_cache = new_k_cache.as_strided(new_k_size, new_k_cache.stride(), storage_offset=0)
+        new_v_cache = new_v_cache.as_strided(new_v_size, new_v_cache.stride(), storage_offset=0)
+        new_k_cache[:, :, :cur_length, :] = k_cache
+        new_v_cache[:, :, :cur_length, :] = v_cache
+    else:
+        new_k_cache = k_cache.as_strided(new_k_size, k_cache.stride(), storage_offset=0)
+        new_v_cache = v_cache.as_strided(new_v_size, v_cache.stride(), storage_offset=0)
+
+    import xe_addons
+    xe_addons.quantize_key_value(key, value,
+                                 new_k_cache[:, :, cur_length:new_length, :],
+                                 new_v_cache[:, :, cur_length:new_length, :])
+
+    return new_k_cache, new_v_cache
+
+
 def restore_fp8_kv_cache(k_cache, v_cache, dtype):
     key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
     value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
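
Both helpers over-allocate FP8_KV_ALLOC_LENGTH extra token slots and expose a growing as_strided view over that storage, so a decode step normally copies only the new tokens; the stride(1) check triggers a real reallocation only once the reserved capacity is used up. A minimal CPU sketch of the same growth pattern in fp16, without the xe_addons quantization (all names and sizes here are illustrative, not part of this diff):

    import torch

    ALLOC_LENGTH = 256  # illustrative stand-in for FP8_KV_ALLOC_LENGTH

    def init_cache(batch, heads, cur_len, head_dim, device="cpu"):
        # allocate room for cur_len + ALLOC_LENGTH tokens, expose a 0-length view
        storage = torch.empty(batch, heads, cur_len + ALLOC_LENGTH, head_dim,
                              dtype=torch.float16, device=device)
        return storage.as_strided((batch, heads, 0, head_dim), storage.stride())

    def append(cache, new_states):
        batch, heads, cur_len, head_dim = cache.shape
        new_len = cur_len + new_states.size(2)
        if cache.stride(1) < new_len * head_dim:
            # reserved capacity exhausted: allocate a larger buffer, copy history
            grown = init_cache(batch, heads, new_len, head_dim, new_states.device)
            grown = grown.as_strided((batch, heads, new_len, head_dim), grown.stride())
            grown[:, :, :cur_len, :] = cache
            cache = grown
        else:
            # enough reserved room: just widen the view over the same storage
            cache = cache.as_strided((batch, heads, new_len, head_dim), cache.stride())
        cache[:, :, cur_len:new_len, :] = new_states  # copy only the new tokens
        return cache

    k_cache = init_cache(1, 8, 16, 192)
    for _ in range(3):  # three decode steps, no reallocation along the way
        k_cache = append(k_cache, torch.randn(1, 8, 1, 192, dtype=torch.float16))
    assert k_cache.shape == (1, 8, 3, 192)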