diff --git a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md index 615f8307..7b3bba4f 100644 --- a/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md +++ b/python/llm/example/NPU/HF-Transformers-AutoModels/LLM/README.md @@ -92,9 +92,9 @@ Supported models: Llama2-7B, Llama3-8B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, MiniCP ### Recommended NPU Driver Version for LNL Users #### 32.0.100.2625 -Supported models: Llama2-7B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, Baichuan2-7B +Supported models: Llama2-7B, MiniCPM-1B, Baichuan2-7B #### 32.0.101.2715 -Supported models: Llama3-8B, MiniCPM-2B +Supported models: Llama3-8B, MiniCPM-2B, Qwen2-7B, Qwen2-1.5B ### Run ```bash diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 9ddad939..dc896236 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -634,7 +634,7 @@ def run_decode( with torch.inference_mode(): while True: - dist.broadcast(control, src=0) + dist.broadcast(control, src=0, async_op=False) if control.item() == -2: break elif control.item() == -1: @@ -708,6 +708,8 @@ class DecodeRunner: self.output_queues = [] self.decoder_processes = [] + self.forward_signal = torch.tensor(0, dtype=torch.int) + for rank in range(1, world_size): input_q = mp.Queue() output_q = mp.Queue() @@ -763,13 +765,14 @@ class DecodeRunner: for i in range(len(self.decoder_processes)): self.input_queues[i].put(past_key_value) - control = torch.tensor(0, dtype=torch.int) - dist.broadcast(control, src=0) + t0 = time.perf_counter() + dist.broadcast(self.forward_signal, src=0, async_op=True) + t1 = time.perf_counter() hidden_states = hidden_states.to(torch.float16) dist.send(hidden_states, dst=1) past_key_value.expand(self.transpose_value_cache) dist.recv(hidden_states, src=self.world_size - 1) - t1 = time.perf_counter() + t2 = time.perf_counter() return hidden_states, past_key_value def shutdown(self):