Optimize broadcast for npu llama (#12028)
This commit is contained in:
parent
e5581e6ded
commit
58555bd9de
2 changed files with 8 additions and 6 deletions
|
|
@ -118,6 +118,7 @@ class _BaseAutoModelClass:
|
|||
ignore_argument(kwargs, "pipeline_parallel_stages")
|
||||
optimize_model = kwargs.pop("optimize_model", False)
|
||||
max_output_len = kwargs.pop("max_output_len", 1024)
|
||||
max_output_len = max_output_len - 1
|
||||
max_prompt_len = kwargs.pop("max_prompt_len", 512)
|
||||
inter_pp = kwargs.pop("inter_pp", None)
|
||||
intra_pp = kwargs.pop("intra_pp", None)
|
||||
|
|
|
|||
|
|
@ -530,7 +530,7 @@ def run_decode(
|
|||
with torch.inference_mode():
|
||||
while True:
|
||||
|
||||
dist.broadcast(control, src=0)
|
||||
dist.broadcast(control, src=0, async_op=False)
|
||||
if control.item() == -2:
|
||||
break
|
||||
elif control.item() == -1:
|
||||
|
|
@ -595,6 +595,8 @@ class DecodeRunner:
|
|||
self.output_queues = []
|
||||
self.decoder_processes = []
|
||||
|
||||
self.forward_signal = torch.tensor(0, dtype=torch.int)
|
||||
|
||||
for rank in range(1, world_size):
|
||||
input_q = mp.Queue()
|
||||
output_q = mp.Queue()
|
||||
|
|
@ -643,21 +645,20 @@ class DecodeRunner:
|
|||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
if self.cache_past_key_value != past_key_value:
|
||||
control = torch.tensor(-1, dtype=torch.int)
|
||||
dist.broadcast(control, src=0)
|
||||
for i in range(len(self.decoder_processes)):
|
||||
self.input_queues[i].put(past_key_value)
|
||||
|
||||
control = torch.tensor(0, dtype=torch.int)
|
||||
dist.broadcast(control, src=0)
|
||||
t0 = time.perf_counter()
|
||||
dist.broadcast(self.forward_signal, src=0, async_op=True)
|
||||
t1 = time.perf_counter()
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
dist.send(hidden_states, dst=1)
|
||||
past_key_value.expand(self.transpose_value_cache)
|
||||
dist.recv(hidden_states, src=self.world_size - 1)
|
||||
t1 = time.perf_counter()
|
||||
t2 = time.perf_counter()
|
||||
return hidden_states, past_key_value
|
||||
|
||||
def shutdown(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue