Optimize broadcast for NPU Llama (#12028)
This commit is contained in:
parent
e5581e6ded
commit
58555bd9de
2 changed files with 8 additions and 6 deletions
|
|
@ -118,6 +118,7 @@ class _BaseAutoModelClass:
|
||||||
ignore_argument(kwargs, "pipeline_parallel_stages")
|
ignore_argument(kwargs, "pipeline_parallel_stages")
|
||||||
optimize_model = kwargs.pop("optimize_model", False)
|
optimize_model = kwargs.pop("optimize_model", False)
|
||||||
max_output_len = kwargs.pop("max_output_len", 1024)
|
max_output_len = kwargs.pop("max_output_len", 1024)
|
||||||
|
max_output_len = max_output_len - 1
|
||||||
max_prompt_len = kwargs.pop("max_prompt_len", 512)
|
max_prompt_len = kwargs.pop("max_prompt_len", 512)
|
||||||
inter_pp = kwargs.pop("inter_pp", None)
|
inter_pp = kwargs.pop("inter_pp", None)
|
||||||
intra_pp = kwargs.pop("intra_pp", None)
|
intra_pp = kwargs.pop("intra_pp", None)
|
||||||
|
|
|
||||||
|
|
@ -530,7 +530,7 @@ def run_decode(
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
dist.broadcast(control, src=0)
|
dist.broadcast(control, src=0, async_op=False)
|
||||||
if control.item() == -2:
|
if control.item() == -2:
|
||||||
break
|
break
|
||||||
elif control.item() == -1:
|
elif control.item() == -1:
|
||||||
|
|
@ -595,6 +595,8 @@ class DecodeRunner:
|
||||||
self.output_queues = []
|
self.output_queues = []
|
||||||
self.decoder_processes = []
|
self.decoder_processes = []
|
||||||
|
|
||||||
|
self.forward_signal = torch.tensor(0, dtype=torch.int)
|
||||||
|
|
||||||
for rank in range(1, world_size):
|
for rank in range(1, world_size):
|
||||||
input_q = mp.Queue()
|
input_q = mp.Queue()
|
||||||
output_q = mp.Queue()
|
output_q = mp.Queue()
|
||||||
|
|
@ -643,21 +645,20 @@ class DecodeRunner:
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
t0 = time.perf_counter()
|
|
||||||
|
|
||||||
if self.cache_past_key_value != past_key_value:
|
if self.cache_past_key_value != past_key_value:
|
||||||
control = torch.tensor(-1, dtype=torch.int)
|
control = torch.tensor(-1, dtype=torch.int)
|
||||||
dist.broadcast(control, src=0)
|
dist.broadcast(control, src=0)
|
||||||
for i in range(len(self.decoder_processes)):
|
for i in range(len(self.decoder_processes)):
|
||||||
self.input_queues[i].put(past_key_value)
|
self.input_queues[i].put(past_key_value)
|
||||||
|
t0 = time.perf_counter()
|
||||||
control = torch.tensor(0, dtype=torch.int)
|
dist.broadcast(self.forward_signal, src=0, async_op=True)
|
||||||
dist.broadcast(control, src=0)
|
t1 = time.perf_counter()
|
||||||
hidden_states = hidden_states.to(torch.float16)
|
hidden_states = hidden_states.to(torch.float16)
|
||||||
dist.send(hidden_states, dst=1)
|
dist.send(hidden_states, dst=1)
|
||||||
past_key_value.expand(self.transpose_value_cache)
|
past_key_value.expand(self.transpose_value_cache)
|
||||||
dist.recv(hidden_states, src=self.world_size - 1)
|
dist.recv(hidden_states, src=self.world_size - 1)
|
||||||
t1 = time.perf_counter()
|
t2 = time.perf_counter()
|
||||||
return hidden_states, past_key_value
|
return hidden_states, past_key_value
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue