From 57edf2033cfd174d6c913dfdba8f0cad73e8f573 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Fri, 19 Apr 2024 16:24:56 +0800 Subject: [PATCH] fix lookahead with transformers >= 4.36 (#10808) --- .../llm/src/ipex_llm/transformers/speculative.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 9d8880cc..2a2a1ddf 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -451,6 +451,20 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False): + from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache + if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)): + if hasattr(past_key_values, "_seen_tokens"): + past_key_values._seen_tokens -= new_cache_size + else: + past_key_values.seen_tokens -= new_cache_size + + for i, k in enumerate(past_key_values.key_cache): + past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :] + for i, v in enumerate(past_key_values.value_cache): + past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :] + + return past_key_values + if _enable_ipex: cur_len = past_key_values[0][0].size(1) delta = new_cache_size