[Fix] LLM: Fix condition check error for speculative decoding on CPU (#10402)

Fix condition check error for speculative decoding on CPU
This commit is contained in:
Xiangyu Tian 2024-03-13 16:05:06 +08:00 committed by GitHub
parent f158b49835
commit e10de2c42d

View file

@ -609,8 +609,11 @@ def speculative_generate(self,
draft_current_input_ids = current_input_ids draft_current_input_ids = current_input_ids
# Target model KV cache to draft model # Target model KV cache to draft model
if self.device.type == 'cpu' and (not _enable_ipex): if self.device.type == 'cpu':
# init past_key_values_storage and assign initial fp32 value # init past_key_values_storage and assign initial fp32 value
if _enable_ipex:
draft_past_key_values = past_key_values
else:
if step == 1: if step == 1:
past_key_values_storage = \ past_key_values_storage = \
_prepare_past_key_values_storage_cpu(self, past_key_values, _prepare_past_key_values_storage_cpu(self, past_key_values,