#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/kv_cache.py
# which is licensed under the MIT license:
#
# MIT License
#
# Copyright (c) 2023 MIT HAN Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch


def slice1d(x, start, end):
    """Return ``x[:, start:end, ...]``, i.e. positions ``[start, end)`` of dim 1."""
    return x[:, start:end, ...]


def slice2d(x, start, end):
    """Return ``x[:, :, start:end, ...]``, i.e. positions ``[start, end)`` of dim 2."""
    return x[:, :, start:end, ...]


def slice3d(x, start, end):
    """Return ``x[:, :, :, start:end, ...]``, i.e. positions ``[start, end)`` of dim 3."""
    return x[:, :, :, start:end, ...]


# Maps the index of a tensor's sequence dimension to the helper slicing along it.
DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}


class StartRecentKVCache:
    """KV cache trimmer that keeps the first ``start_size`` and the last
    ``recent_size`` positions of every layer's key/value tensors and drops
    everything in between.

    Args:
        start_size: number of leading positions that are always retained.
        recent_size: number of trailing positions that are retained.
        k_seq_dim: dimension of the key tensors that indexes sequence position;
            must be a key of ``DIM_TO_SLICE`` (1, 2 or 3), else ``KeyError``.
        v_seq_dim: same as ``k_seq_dim``, for the value tensors.

    ``past_key_values`` arguments are iterables of ``(key, value)`` tensor pairs,
    one per layer; the sequence length is read from the first layer's key.
    """

    def __init__(
        self,
        start_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        print(f"StartRecentKVCache: {start_size}, {recent_size}")
        self.start_size = start_size
        self.recent_size = recent_size
        # Maximum number of positions retained after a trim.
        self.cache_size = start_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]

    def _keep(self, past_key_values, head_end, tail_start, seq_len):
        """Keep positions ``[0, head_end)`` and ``[tail_start, seq_len)`` of every layer.

        Returns a new list of ``[key, value]`` lists built with ``torch.cat``
        along each tensor's sequence dimension. The input tensors are only
        sliced, never written to. Out-of-range or empty ranges follow Python
        slicing rules (they are clamped / yield zero positions).
        """
        return [
            [
                torch.cat(
                    [
                        self.k_slice(k, 0, head_end),
                        self.k_slice(k, tail_start, seq_len),
                    ],
                    dim=self.k_seq_dim,
                ),
                torch.cat(
                    [
                        self.v_slice(v, 0, head_end),
                        self.v_slice(v, tail_start, seq_len),
                    ],
                    dim=self.v_seq_dim,
                ),
            ]
            for k, v in past_key_values
        ]

    def __call__(self, past_key_values):
        """Trim ``past_key_values`` to at most ``cache_size`` positions.

        Returns ``None`` for ``None`` input, and the input object itself when
        it already fits. Otherwise returns a new list holding the first
        ``start_size`` and the last ``recent_size`` positions of each layer.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        return self._keep(
            past_key_values, self.start_size, seq_len - self.recent_size, seq_len
        )

    def evict_for_space(self, past_key_values, num_coming):
        """Free room for ``num_coming`` positions that are about to be appended.

        Returns ``None`` for ``None`` input, and the input object itself when
        ``seq_len + num_coming <= cache_size``. Otherwise keeps the first
        ``start_size`` positions and the last ``recent_size - num_coming``
        positions. With ``num_coming <= recent_size`` the cache then holds
        exactly ``cache_size`` positions once the new ones are appended.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        # Shift the recent window forward by num_coming. If num_coming exceeds
        # recent_size this start lies past seq_len and the recent part is empty.
        tail_start = seq_len - self.recent_size + num_coming
        return self._keep(past_key_values, self.start_size, tail_start, seq_len)

    def evict_range(self, past_key_values, start, end):
        """Drop positions ``[start, end)`` from every layer.

        Returns ``None`` for ``None`` input, otherwise a new list of
        ``[key, value]`` lists without the evicted positions.

        Raises:
            AssertionError: unless ``0 <= start <= end <= seq_len``.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        # A negative start would be interpreted from the end by slicing and
        # silently corrupt the cache, so it is rejected along with the rest.
        assert 0 <= start <= end <= seq_len, (
            f"invalid eviction range [{start}, {end}) for sequence length {seq_len}"
        )
        return self._keep(past_key_values, start, end, seq_len)