[Fix] LLM: Fix condition check error for speculative decoding on CPU (#10402)
Fix condition check error for speculative decoding on CPU
This commit is contained in:
parent
f158b49835
commit
e10de2c42d
1 changed file with 13 additions and 10 deletions
|
|
@@ -609,17 +609,20 @@ def speculative_generate(self,
|
||||||
draft_current_input_ids = current_input_ids
|
draft_current_input_ids = current_input_ids
|
||||||
# Target model KV cache to draft model
|
# Target model KV cache to draft model
|
||||||
|
|
||||||
if self.device.type == 'cpu' and (not _enable_ipex):
|
if self.device.type == 'cpu':
|
||||||
# init past_key_values_storage and assign initial fp32 value
|
# init past_key_values_storage and assign initial fp32 value
|
||||||
if step == 1:
|
if _enable_ipex:
|
||||||
past_key_values_storage = \
|
draft_past_key_values = past_key_values
|
||||||
_prepare_past_key_values_storage_cpu(self, past_key_values,
|
else:
|
||||||
max_new_tokens, _enable_ipex)
|
if step == 1:
|
||||||
# each iter cut off cur_len kv_cache from past_key_values1
|
past_key_values_storage = \
|
||||||
draft_past_key_values = \
|
_prepare_past_key_values_storage_cpu(self, past_key_values,
|
||||||
_prepare_draft_past_key_values_cpu(self, past_key_values,
|
max_new_tokens, _enable_ipex)
|
||||||
past_key_values_storage, _enable_ipex)
|
# each iter cut off cur_len kv_cache from past_key_values1
|
||||||
original_draft_past_key_values = draft_past_key_values
|
draft_past_key_values = \
|
||||||
|
_prepare_draft_past_key_values_cpu(self, past_key_values,
|
||||||
|
past_key_values_storage, _enable_ipex)
|
||||||
|
original_draft_past_key_values = draft_past_key_values
|
||||||
else:
|
else:
|
||||||
past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
|
past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
|
||||||
max_step_draft,
|
max_step_draft,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue