From 7cc01fdc8642edcd1d2226e24b7adbda46d1b547 Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Thu, 12 Dec 2024 21:42:00 -0800 Subject: [PATCH] [NPU] further fix of `new_value_states` (#12538) --- .../llm/src/ipex_llm/transformers/npu_models/mp_models_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 610cbc1e..8f2d2507 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -213,6 +213,7 @@ class LLMBaseNNFactory(NNFactory): value_states = new_value_states else: value_states = self.transpose(value_states, [0, 2, 1, 3]) + new_value_states = value_states query_states, key_states = self.apply_rotary_pos_emb( q=query_states, @@ -225,7 +226,6 @@ class LLMBaseNNFactory(NNFactory): head_dim=head_dim, ) new_key_states = key_states - new_value_states = value_states if mode == "decode": key_states = self.concat(past_key, key_states, axis=-2)