fix lookahead with transformers >= 4.36 (#10808)

This commit is contained in:
Yishuo Wang 2024-04-19 16:24:56 +08:00 committed by GitHub
parent 34ff07b689
commit 57edf2033c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -451,6 +451,20 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l
def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False): def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False):
from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache
if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)):
if hasattr(past_key_values, "_seen_tokens"):
past_key_values._seen_tokens -= new_cache_size
else:
past_key_values.seen_tokens -= new_cache_size
for i, k in enumerate(past_key_values.key_cache):
past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :]
for i, v in enumerate(past_key_values.value_cache):
past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :]
return past_key_values
if _enable_ipex: if _enable_ipex:
cur_len = past_key_values[0][0].size(1) cur_len = past_key_values[0][0].size(1)
delta = new_cache_size delta = new_cache_size