diff --git a/python/llm/src/ipex_llm/transformers/npu_model.py b/python/llm/src/ipex_llm/transformers/npu_model.py
index 63487dfa..48df4d48 100644
--- a/python/llm/src/ipex_llm/transformers/npu_model.py
+++ b/python/llm/src/ipex_llm/transformers/npu_model.py
@@ -118,6 +118,7 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "pipeline_parallel_stages")
         optimize_model = kwargs.pop("optimize_model", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
+        max_output_len = max_output_len - 1
         max_prompt_len = kwargs.pop("max_prompt_len", 512)
         inter_pp = kwargs.pop("inter_pp", None)
         intra_pp = kwargs.pop("intra_pp", None)
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
index 46c4236f..f14dd859 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -530,7 +530,7 @@ def run_decode(

     with torch.inference_mode():
         while True:
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
@@ -595,6 +595,8 @@ class DecodeRunner:
         self.output_queues = []
         self.decoder_processes = []

+        self.forward_signal = torch.tensor(0, dtype=torch.int)
+
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
@@ -643,21 +645,20 @@
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        t0 = time.perf_counter()

         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
-
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        t0 = time.perf_counter()
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
+        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
+        t2 = time.perf_counter()
         return hidden_states, past_key_value

     def shutdown(self):
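
The forward() change above replaces the blocking rank-0 broadcast of the per-step control tensor with a non-blocking one (and reuses a preallocated self.forward_signal), so the control signal can still be in flight while rank 0 sends the hidden states to the first decode worker. Below is a minimal standalone sketch of that torch.distributed pattern; the single-rank gloo group and the MASTER_ADDR/MASTER_PORT defaults are assumptions made only so the snippet runs on its own and are not part of the patch.

import os

import torch
import torch.distributed as dist

# Single-rank "gloo" group, assumed only so the sketch runs standalone;
# in the patch the group spans rank 0 and the NPU decode worker processes.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

forward_signal = torch.tensor(0, dtype=torch.int)

# With async_op=True the broadcast returns a Work handle immediately instead
# of blocking until the collective completes, so the caller can overlap the
# control signal with other communication (e.g. dist.send of hidden states).
work = dist.broadcast(forward_signal, src=0, async_op=True)

# ... overlapped work would go here ...

work.wait()  # wait only where completion must be guaranteed before reuse
dist.destroy_process_group()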