diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 610cbc1e..8f2d2507 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -213,6 +213,7 @@ class LLMBaseNNFactory(NNFactory): value_states = new_value_states else: value_states = self.transpose(value_states, [0, 2, 1, 3]) + new_value_states = value_states query_states, key_states = self.apply_rotary_pos_emb( q=query_states, @@ -225,7 +226,6 @@ class LLMBaseNNFactory(NNFactory): head_dim=head_dim, ) new_key_states = key_states - new_value_states = value_states if mode == "decode": key_states = self.concat(past_key, key_states, axis=-2)