update NPU readme of Qwen2 (#12032)

* update readme

* update broadcast
This commit is contained in:
Ruonan Wang 2024-09-06 00:02:39 -07:00 committed by GitHub
parent 58555bd9de
commit 0d04531ae0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 9 additions and 6 deletions

View file

@ -92,9 +92,9 @@ Supported models: Llama2-7B, Llama3-8B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, MiniCP
### Recommended NPU Driver Version for LNL Users
#### 32.0.100.2625
Supported models: Llama2-7B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, Baichuan2-7B
Supported models: Llama2-7B, MiniCPM-1B, Baichuan2-7B
#### 32.0.101.2715
Supported models: Llama3-8B, MiniCPM-2B
Supported models: Llama3-8B, MiniCPM-2B, Qwen2-7B, Qwen2-1.5B
### Run
```bash

View file

@ -634,7 +634,7 @@ def run_decode(
with torch.inference_mode():
while True:
dist.broadcast(control, src=0)
dist.broadcast(control, src=0, async_op=False)
if control.item() == -2:
break
elif control.item() == -1:
@ -708,6 +708,8 @@ class DecodeRunner:
self.output_queues = []
self.decoder_processes = []
self.forward_signal = torch.tensor(0, dtype=torch.int)
for rank in range(1, world_size):
input_q = mp.Queue()
output_q = mp.Queue()
@ -763,13 +765,14 @@ class DecodeRunner:
for i in range(len(self.decoder_processes)):
self.input_queues[i].put(past_key_value)
control = torch.tensor(0, dtype=torch.int)
dist.broadcast(control, src=0)
t0 = time.perf_counter()
dist.broadcast(self.forward_signal, src=0, async_op=True)
t1 = time.perf_counter()
hidden_states = hidden_states.to(torch.float16)
dist.send(hidden_states, dst=1)
past_key_value.expand(self.transpose_value_cache)
dist.recv(hidden_states, src=self.world_size - 1)
t1 = time.perf_counter()
t2 = time.perf_counter()
return hidden_states, past_key_value
def shutdown(self):