diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index a1dac609..42cf72e3 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -225,6 +225,7 @@ class LLMBaseNNFactory(NNFactory): head_dim=head_dim, ) new_key_states = key_states + new_value_states = value_states if mode == "decode": key_states = self.concat(past_key, key_states, axis=-2)