Fix hf generate for llama3.2 (#12497)

* fix kv condition

* meet review
This commit is contained in:
Kai Huang 2024-12-04 17:54:40 +08:00 committed by GitHub
parent ffa9a9e1b3
commit 7d27f134dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -455,7 +455,7 @@ def optimize_llm_single_process(
def prepare_input_ids(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
if past_key_values is not None: # kvcache
if past_key_values and isinstance(past_key_values, bool): # kvcache
input_ids = input_ids[:, -1]
else: # prefill, reset the model here
from .npu_llm_cpp import reset
@@ -495,7 +495,7 @@ def causal_lm_forward(
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=1, # just an indicator
past_key_values=True, # just an indicator
hidden_states=None,
attentions=None,
)