parent
58555bd9de
commit
0d04531ae0
2 changed files with 9 additions and 6 deletions
|
|
@@ -92,9 +92,9 @@ Supported models: Llama2-7B, Llama3-8B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, MiniCP
|
|||
|
||||
### Recommended NPU Driver Version for LNL Users
|
||||
#### 32.0.100.2625
|
||||
Supported models: Llama2-7B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, Baichuan2-7B
|
||||
Supported models: Llama2-7B, MiniCPM-1B, Baichuan2-7B
|
||||
#### 32.0.101.2715
|
||||
Supported models: Llama3-8B, MiniCPM-2B
|
||||
Supported models: Llama3-8B, MiniCPM-2B, Qwen2-7B, Qwen2-1.5B
|
||||
|
||||
### Run
|
||||
```bash
|
||||
|
|
|
|||
|
|
@@ -634,7 +634,7 @@ def run_decode(
|
|||
with torch.inference_mode():
|
||||
while True:
|
||||
|
||||
dist.broadcast(control, src=0)
|
||||
dist.broadcast(control, src=0, async_op=False)
|
||||
if control.item() == -2:
|
||||
break
|
||||
elif control.item() == -1:
|
||||
|
|
@@ -708,6 +708,8 @@ class DecodeRunner:
|
|||
self.output_queues = []
|
||||
self.decoder_processes = []
|
||||
|
||||
self.forward_signal = torch.tensor(0, dtype=torch.int)
|
||||
|
||||
for rank in range(1, world_size):
|
||||
input_q = mp.Queue()
|
||||
output_q = mp.Queue()
|
||||
|
|
@@ -763,13 +765,14 @@ class DecodeRunner:
|
|||
for i in range(len(self.decoder_processes)):
|
||||
self.input_queues[i].put(past_key_value)
|
||||
|
||||
control = torch.tensor(0, dtype=torch.int)
|
||||
dist.broadcast(control, src=0)
|
||||
t0 = time.perf_counter()
|
||||
dist.broadcast(self.forward_signal, src=0, async_op=True)
|
||||
t1 = time.perf_counter()
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
dist.send(hidden_states, dst=1)
|
||||
past_key_value.expand(self.transpose_value_cache)
|
||||
dist.recv(hidden_states, src=self.world_size - 1)
|
||||
t1 = time.perf_counter()
|
||||
t2 = time.perf_counter()
|
||||
return hidden_states, past_key_value
|
||||
|
||||
def shutdown(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue