parent
58555bd9de
commit
0d04531ae0
2 changed files with 9 additions and 6 deletions
|
|
@ -92,9 +92,9 @@ Supported models: Llama2-7B, Llama3-8B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, MiniCP
|
||||||
|
|
||||||
### Recommended NPU Driver Version for LNL Users
|
### Recommended NPU Driver Version for LNL Users
|
||||||
#### 32.0.100.2625
|
#### 32.0.100.2625
|
||||||
Supported models: Llama2-7B, Qwen2-1.5B, Qwen2-7B, MiniCPM-1B, Baichuan2-7B
|
Supported models: Llama2-7B, MiniCPM-1B, Baichuan2-7B
|
||||||
#### 32.0.101.2715
|
#### 32.0.101.2715
|
||||||
Supported models: Llama3-8B, MiniCPM-2B
|
Supported models: Llama3-8B, MiniCPM-2B, Qwen2-7B, Qwen2-1.5B
|
||||||
|
|
||||||
### Run
|
### Run
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -634,7 +634,7 @@ def run_decode(
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
dist.broadcast(control, src=0)
|
dist.broadcast(control, src=0, async_op=False)
|
||||||
if control.item() == -2:
|
if control.item() == -2:
|
||||||
break
|
break
|
||||||
elif control.item() == -1:
|
elif control.item() == -1:
|
||||||
|
|
@ -708,6 +708,8 @@ class DecodeRunner:
|
||||||
self.output_queues = []
|
self.output_queues = []
|
||||||
self.decoder_processes = []
|
self.decoder_processes = []
|
||||||
|
|
||||||
|
self.forward_signal = torch.tensor(0, dtype=torch.int)
|
||||||
|
|
||||||
for rank in range(1, world_size):
|
for rank in range(1, world_size):
|
||||||
input_q = mp.Queue()
|
input_q = mp.Queue()
|
||||||
output_q = mp.Queue()
|
output_q = mp.Queue()
|
||||||
|
|
@ -763,13 +765,14 @@ class DecodeRunner:
|
||||||
for i in range(len(self.decoder_processes)):
|
for i in range(len(self.decoder_processes)):
|
||||||
self.input_queues[i].put(past_key_value)
|
self.input_queues[i].put(past_key_value)
|
||||||
|
|
||||||
control = torch.tensor(0, dtype=torch.int)
|
t0 = time.perf_counter()
|
||||||
dist.broadcast(control, src=0)
|
dist.broadcast(self.forward_signal, src=0, async_op=True)
|
||||||
|
t1 = time.perf_counter()
|
||||||
hidden_states = hidden_states.to(torch.float16)
|
hidden_states = hidden_states.to(torch.float16)
|
||||||
dist.send(hidden_states, dst=1)
|
dist.send(hidden_states, dst=1)
|
||||||
past_key_value.expand(self.transpose_value_cache)
|
past_key_value.expand(self.transpose_value_cache)
|
||||||
dist.recv(hidden_states, src=self.world_size - 1)
|
dist.recv(hidden_states, src=self.world_size - 1)
|
||||||
t1 = time.perf_counter()
|
t2 = time.perf_counter()
|
||||||
return hidden_states, past_key_value
|
return hidden_states, past_key_value
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue