diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py
index d37833bb..323a1d9c 100644
--- a/python/llm/src/bigdl/llm/transformers/speculative.py
+++ b/python/llm/src/bigdl/llm/transformers/speculative.py
@@ -609,17 +609,20 @@ def speculative_generate(self,
             draft_current_input_ids = current_input_ids
             # Target model KV cache to draft model
-            if self.device.type == 'cpu' and (not _enable_ipex):
+            if self.device.type == 'cpu':
                 # init past_key_values_storage and assign initial fp32 value
-                if step == 1:
-                    past_key_values_storage = \
-                        _prepare_past_key_values_storage_cpu(self, past_key_values,
-                                                             max_new_tokens, _enable_ipex)
-                # each iter cut off cur_len kv_cache from past_key_values1
-                draft_past_key_values = \
-                    _prepare_draft_past_key_values_cpu(self, past_key_values,
-                                                       past_key_values_storage, _enable_ipex)
-                original_draft_past_key_values = draft_past_key_values
+                if _enable_ipex:
+                    draft_past_key_values = past_key_values
+                else:
+                    if step == 1:
+                        past_key_values_storage = \
+                            _prepare_past_key_values_storage_cpu(self, past_key_values,
+                                                                 max_new_tokens, _enable_ipex)
+                    # each iter cut off cur_len kv_cache from past_key_values1
+                    draft_past_key_values = \
+                        _prepare_draft_past_key_values_cpu(self, past_key_values,
+                                                           past_key_values_storage, _enable_ipex)
+                    original_draft_past_key_values = draft_past_key_values
             else:
                 past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
                                                                         max_step_draft,