diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index afcb8f68..5c5ca0ce 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -601,7 +601,7 @@ class FusedLlamaLowBitMultiDecoderlayer(torch.nn.Module):
             x_np[2].ctypes.data_as(ctypes.c_void_p),
         )
         t0 = time.perf_counter()
-        backend_lib.run_decoders(models_ptr, inputs_ptr, 2, 3)
+        backend_lib.run_decoders(models_ptr, inputs_ptr, self.intra_stages, 3)
         t1 = time.perf_counter()
         hidden_states = self.backend_decoders[-1].torch_out[0]