From 6693e8ab04c8518a9e9affc62f5a9752eb24a74f Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Fri, 11 Apr 2025 11:26:15 +0800
Subject: [PATCH] Deepseek kv / sdp support (#13068)

* update kv

* fix

* fix style
---
 python/llm/src/ipex_llm/transformers/kv.py    | 52 ++++++++++++++++++-
 .../ipex_llm/transformers/models/common.py    |  6 +--
 .../src/ipex_llm/transformers/models/utils.py | 43 +++++++++++++++
 3 files changed, 97 insertions(+), 4 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index b0e52282..1dc04228 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -22,7 +22,8 @@ import math
 
 from .models.utils import (
     init_fp8_kv_cache, append_fp8_kv_cache,
-    init_kv_cache, append_kv_cache, extend_kv_cache
+    init_kv_cache, append_kv_cache, extend_kv_cache,
+    init_unbalanced_fp8_kv_cache, append_unbalanced_fp8_kv_cache,
 )
 from typing import Optional, Dict, Tuple, Any, List
 from transformers.cache_utils import DynamicCache
@@ -151,6 +152,55 @@ class DynamicNormalCache(DynamicCache):
         return past_key_values
 
 
+class DynamicUnbalancedFp8Cache(DynamicCache):
+    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
+        # ignore num_hidden_layers to fix transformers >= 4.45
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]]=None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # fix converting empty DynamicCache in transformers >= 4.45
+        if key_states == []:
+            return key_states, value_states
+
+        batch_size, num_heads, seq_len, k_head_dim = key_states.shape
+        _, _, _, v_head_dim = value_states.shape
+
+        if layer_idx == 0:
+            if hasattr(self, "_seen_tokens"):
+                # 4.39 uses `_seen_tokens`
+                self._seen_tokens += seq_len
+            else:
+                # 4.37 uses `seen_tokens`
+                self.seen_tokens += seq_len
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            k_cache, v_cache = init_unbalanced_fp8_kv_cache(
+                batch_size, num_heads, seq_len, k_head_dim, v_head_dim,
+                device=key_states.device,
+            )
+            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
+                                                              key_states, value_states)
+
+            self.key_cache.append(k_cache)
+            self.value_cache.append(v_cache)
+        else:
+            k_cache = self.key_cache[layer_idx]
+            v_cache = self.value_cache[layer_idx]
+            k_cache, v_cache = append_unbalanced_fp8_kv_cache(k_cache, v_cache,
+                                                              key_states, value_states)
+            self.key_cache[layer_idx] = k_cache
+            self.value_cache[layer_idx] = v_cache
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 2776c3aa..27337c94 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -273,11 +273,11 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         else:
             attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
     elif seq_length != kv_length and seq_length <= 32:
-        # todo: add scale support
+        # todo: add further scale support
        if key.dtype == torch.uint8:
-            attn_output = xe_addons.sdp_fp8(query, key, value, mask)
+            attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
         else:
-            attn_output = xe_addons.sdp(query, key, value, mask)
+            attn_output = xe_addons.sdp(query, key, value, mask, scale)
     else:
         if key.dtype == torch.uint8:
             attn_output = xe_addons.sdp_fp8_non_causal(query, key, value, mask, scale)
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 0e3e897c..d43fc51f 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -138,6 +138,49 @@ def append_fp8_kv_cache(k_cache, v_cache, key, value):
     return new_k_cache, new_v_cache
 
 
+def init_unbalanced_fp8_kv_cache(batch_size, num_heads, current_length,
+                                 k_head_dim, v_head_dim, device):
+    # for case which k head dim is different from v head dim
+    max_length = current_length + FP8_KV_ALLOC_LENGTH
+
+    k_cache_storage = torch.empty(batch_size, num_heads, max_length, k_head_dim,
+                                  dtype=torch.uint8, device=device)
+    k_cache = k_cache_storage.as_strided((batch_size, num_heads, 0, k_head_dim),
+                                         k_cache_storage.stride(), storage_offset=0)
+
+    v_cache_storage = torch.empty(batch_size, num_heads, max_length, v_head_dim,
+                                  dtype=torch.uint8, device=device)
+    v_cache = v_cache_storage.as_strided((batch_size, num_heads, 0, v_head_dim),
+                                         v_cache_storage.stride(), storage_offset=0)
+    return k_cache, v_cache
+
+
+def append_unbalanced_fp8_kv_cache(k_cache, v_cache, key, value):
+    batch_size, num_heads, cur_length, k_head_dim = k_cache.shape
+    _, _, _, v_head_dim = v_cache.shape
+    new_length = cur_length + key.size(2)
+    new_k_size = (batch_size, num_heads, new_length, k_head_dim)
+    new_v_size = (batch_size, num_heads, new_length, v_head_dim)
+
+    if k_cache.stride(1) < new_length * k_cache.size(3):
+        new_k_cache, new_v_cache = init_unbalanced_fp8_kv_cache(batch_size, num_heads, new_length,
+                                                                k_head_dim, v_head_dim, key.device)
+        new_k_cache = new_k_cache.as_strided(new_k_size, new_k_cache.stride(), storage_offset=0)
+        new_v_cache = new_v_cache.as_strided(new_v_size, new_v_cache.stride(), storage_offset=0)
+        new_k_cache[:, :, :cur_length, :] = k_cache
+        new_v_cache[:, :, :cur_length, :] = v_cache
+    else:
+        new_k_cache = k_cache.as_strided(new_k_size, k_cache.stride(), storage_offset=0)
+        new_v_cache = v_cache.as_strided(new_v_size, v_cache.stride(), storage_offset=0)
+
+    import xe_addons
+    xe_addons.quantize_key_value(key, value,
+                                 new_k_cache[:, :, cur_length:new_length, :],
+                                 new_v_cache[:, :, cur_length:new_length, :])
+
+    return new_k_cache, new_v_cache
+
+
 def restore_fp8_kv_cache(k_cache, v_cache, dtype):
     key_states = torch.empty(k_cache.shape, device=k_cache.device, dtype=dtype)
     value_states = torch.empty(v_cache.shape, device=v_cache.device, dtype=dtype)
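
Note (illustration, not part of the patch): the sketch below mirrors the pre-allocate-and-grow pattern that init_unbalanced_fp8_kv_cache / append_unbalanced_fp8_kv_cache follow for the DeepSeek-style case where the key head dim differs from the value head dim. It is a minimal CPU-only approximation under stated assumptions: plain float16 copies stand in for the xe_addons.quantize_key_value FP8 path, KV_ALLOC_STEP stands in for the real FP8_KV_ALLOC_LENGTH constant, and the helper names and head dims are illustrative only.

import torch

KV_ALLOC_STEP = 256  # stand-in for FP8_KV_ALLOC_LENGTH

def init_unbalanced_kv_cache(batch, heads, cur_len, k_dim, v_dim, device, dtype=torch.float16):
    # Reserve extra sequence slots up front and expose a zero-length view,
    # so later appends can widen the view in place until storage runs out.
    max_len = cur_len + KV_ALLOC_STEP
    k_storage = torch.empty(batch, heads, max_len, k_dim, dtype=dtype, device=device)
    v_storage = torch.empty(batch, heads, max_len, v_dim, dtype=dtype, device=device)
    k_cache = k_storage.as_strided((batch, heads, 0, k_dim), k_storage.stride())
    v_cache = v_storage.as_strided((batch, heads, 0, v_dim), v_storage.stride())
    return k_cache, v_cache

def append_unbalanced_kv_cache(k_cache, v_cache, key, value):
    batch, heads, cur_len, k_dim = k_cache.shape
    v_dim = v_cache.size(3)
    new_len = cur_len + key.size(2)
    if k_cache.stride(1) < new_len * k_dim:
        # Reserved storage exhausted: allocate a bigger buffer and copy the old cache over.
        new_k, new_v = init_unbalanced_kv_cache(batch, heads, new_len, k_dim, v_dim,
                                                key.device, key.dtype)
        new_k = new_k.as_strided((batch, heads, new_len, k_dim), new_k.stride())
        new_v = new_v.as_strided((batch, heads, new_len, v_dim), new_v.stride())
        new_k[:, :, :cur_len, :] = k_cache
        new_v[:, :, :cur_len, :] = v_cache
    else:
        # Still room in the reserved storage: just widen the existing views.
        new_k = k_cache.as_strided((batch, heads, new_len, k_dim), k_cache.stride())
        new_v = v_cache.as_strided((batch, heads, new_len, v_dim), v_cache.stride())
    # The patch writes FP8 data into the new slots via xe_addons.quantize_key_value;
    # here the new tokens are simply copied.
    new_k[:, :, cur_len:new_len, :] = key
    new_v[:, :, cur_len:new_len, :] = value
    return new_k, new_v

# Unbalanced head dims (k_dim != v_dim), e.g. DeepSeek-style attention.
k_cache, v_cache = init_unbalanced_kv_cache(1, 8, 16, 192, 128, "cpu")
key = torch.randn(1, 8, 16, 192, dtype=torch.float16)
value = torch.randn(1, 8, 16, 128, dtype=torch.float16)
k_cache, v_cache = append_unbalanced_kv_cache(k_cache, v_cache, key, value)
print(k_cache.shape, v_cache.shape)  # [1, 8, 16, 192] / [1, 8, 16, 128]

This is the reason DynamicUnbalancedFp8Cache can hand back growing key/value tensors without reallocating on every decode step: the views only spill into a fresh buffer once the reserved FP8_KV_ALLOC_LENGTH slots are used up.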