From 3daad242b833b880755031969e1767f8a490125a Mon Sep 17 00:00:00 2001 From: Yina Chen <33650826+cyita@users.noreply.github.com> Date: Mon, 22 Apr 2024 14:05:50 +0800 Subject: [PATCH] Fix No module named 'transformers.cache_utils' with transformers < 4.36 (#10835) * update sdp condition * update * fix * fix 431 error * revert sdp & style fix * fix * meet comments --- .../llm/src/ipex_llm/transformers/lookup.py | 3 ++- .../src/ipex_llm/transformers/speculative.py | 23 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/lookup.py b/python/llm/src/ipex_llm/transformers/lookup.py index 8cce463a..41d40f28 100644 --- a/python/llm/src/ipex_llm/transformers/lookup.py +++ b/python/llm/src/ipex_llm/transformers/lookup.py @@ -316,7 +316,8 @@ def lookup_generate(self, if max_of_max_matched != max_matched: output_ids = output_ids[:, :max_matched] new_cache_size = max_of_max_matched - max_matched - past_key_values = _crop_past_key_values(self, past_key_values, new_cache_size) + past_key_values = _crop_past_key_values(self, past_key_values, + new_cache_size) input_ids = torch.cat((input_ids, output_ids), dim=-1) diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 2a2a1ddf..1f7a84bd 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -451,19 +451,20 @@ def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_l def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=False): - from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache - if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)): - if hasattr(past_key_values, "_seen_tokens"): - past_key_values._seen_tokens -= new_cache_size - else: - past_key_values.seen_tokens -= new_cache_size + if version.parse(trans_version) >= version.parse("4.36.0"): + from ipex_llm.transformers.kv import DynamicFp8Cache, DynamicNormalCache + if isinstance(past_key_values, (DynamicFp8Cache, DynamicNormalCache)): + if hasattr(past_key_values, "_seen_tokens"): + past_key_values._seen_tokens -= new_cache_size + else: + past_key_values.seen_tokens -= new_cache_size - for i, k in enumerate(past_key_values.key_cache): - past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :] - for i, v in enumerate(past_key_values.value_cache): - past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :] + for i, k in enumerate(past_key_values.key_cache): + past_key_values.key_cache[i] = k[:, :, :-new_cache_size, :] + for i, v in enumerate(past_key_values.value_cache): + past_key_values.value_cache[i] = v[:, :, :-new_cache_size, :] - return past_key_values + return past_key_values if _enable_ipex: cur_len = past_key_values[0][0].size(1)