Optimize broadcast for npu llama (#12028)

Author: Yang Wang
Date:   2024-09-05 22:28:20 -07:00 (committed by GitHub)
parent e5581e6ded
commit 58555bd9de
2 changed files with 8 additions and 6 deletions


@@ -118,6 +118,7 @@ class _BaseAutoModelClass:
         ignore_argument(kwargs, "pipeline_parallel_stages")
         optimize_model = kwargs.pop("optimize_model", False)
         max_output_len = kwargs.pop("max_output_len", 1024)
+        max_output_len = max_output_len - 1
         max_prompt_len = kwargs.pop("max_prompt_len", 512)
         inter_pp = kwargs.pop("inter_pp", None)
         intra_pp = kwargs.pop("intra_pp", None)
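
The added line shaves one step off the user-supplied generation budget before it is passed down to the NPU runtime. Below is a minimal sketch of the kwargs-popping pattern above, assuming (the diff itself does not say) that the decrement reserves one kv-cache position for the token currently being generated; load_npu_options is a hypothetical helper name:

# Hypothetical helper illustrating the option handling in _BaseAutoModelClass;
# the off-by-one rationale is an assumption, not stated by the commit.
def load_npu_options(**kwargs):
    optimize_model = kwargs.pop("optimize_model", False)
    # A user-facing budget of 1024 becomes an internal length of 1023,
    # presumably leaving one slot for the token being decoded.
    max_output_len = kwargs.pop("max_output_len", 1024) - 1
    max_prompt_len = kwargs.pop("max_prompt_len", 512)
    return optimize_model, max_output_len, max_prompt_len

print(load_npu_options(max_output_len=2048))  # (False, 2047, 512)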


@@ -530,7 +530,7 @@ def run_decode(
     with torch.inference_mode():
         while True:
-            dist.broadcast(control, src=0)
+            dist.broadcast(control, src=0, async_op=False)
             if control.item() == -2:
                 break
             elif control.item() == -1:
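
Since async_op=False is already the default for dist.broadcast, the worker-side change is purely documentary: it makes explicit that each decode worker blocks on the control word from rank 0 before acting on it. A minimal sketch of the two modes, assuming an initialized process group; the helper names are hypothetical:

import torch
import torch.distributed as dist

def wait_for_control() -> int:
    # Worker side (as in run_decode): async_op=False blocks until the
    # control word broadcast by rank 0 has landed in `control`.
    control = torch.tensor(0, dtype=torch.int)
    dist.broadcast(control, src=0, async_op=False)
    return control.item()

def fire_control(value: int):
    # Driver side: async_op=True returns a Work handle immediately, so
    # rank 0 can overlap other CPU work; call .wait() before reusing
    # or freeing the tensor.
    control = torch.tensor(value, dtype=torch.int)
    work = dist.broadcast(control, src=0, async_op=True)
    return control, work  # caller is responsible for work.wait()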
@@ -595,6 +595,8 @@ class DecodeRunner:
         self.output_queues = []
         self.decoder_processes = []
+        self.forward_signal = torch.tensor(0, dtype=torch.int)
         for rank in range(1, world_size):
             input_q = mp.Queue()
             output_q = mp.Queue()
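
Hoisting the zero-valued control tensor into __init__ as self.forward_signal means the per-token forward path no longer allocates a fresh tensor, and the asynchronous broadcast in the next hunk gets a buffer whose lifetime outlasts the call. A minimal sketch of the setup, with the process spawning omitted; names mirror the diff:

import torch
import torch.multiprocessing as mp

class DecodeRunnerSketch:
    def __init__(self, world_size: int):
        self.world_size = world_size
        self.input_queues = []
        self.output_queues = []
        self.decoder_processes = []
        # Allocated once and reused every decode step; broadcasting the
        # value 0 tells the workers "run a forward pass".
        self.forward_signal = torch.tensor(0, dtype=torch.int)
        for rank in range(1, world_size):
            # One queue pair per decoder worker for passing kv-cache state.
            self.input_queues.append(mp.Queue())
            self.output_queues.append(mp.Queue())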
@@ -643,21 +645,20 @@ class DecodeRunner:
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
-        t0 = time.perf_counter()
         if self.cache_past_key_value != past_key_value:
             control = torch.tensor(-1, dtype=torch.int)
             dist.broadcast(control, src=0)
             for i in range(len(self.decoder_processes)):
                 self.input_queues[i].put(past_key_value)
-        control = torch.tensor(0, dtype=torch.int)
-        dist.broadcast(control, src=0)
+        t0 = time.perf_counter()
+        dist.broadcast(self.forward_signal, src=0, async_op=True)
+        t1 = time.perf_counter()
         hidden_states = hidden_states.to(torch.float16)
         dist.send(hidden_states, dst=1)
         past_key_value.expand(self.transpose_value_cache)
         dist.recv(hidden_states, src=self.world_size - 1)
-        t1 = time.perf_counter()
+        t2 = time.perf_counter()
         return hidden_states, past_key_value

     def shutdown(self):
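
This hunk is the substantive optimization: instead of allocating a control tensor and broadcasting it synchronously on every decode step, rank 0 now broadcasts the pre-allocated self.forward_signal with async_op=True, so the fp16 cast and the send to the first pipeline stage proceed while the signal is still in flight; the timers now separate broadcast launch cost (t1 - t0) from the pipeline round trip (t2 - t1). A minimal sketch of this overlap pattern, assuming an initialized process group and a runner shaped like the sketch above:

import time
import torch
import torch.distributed as dist

def forward_step(runner, hidden_states):
    t0 = time.perf_counter()
    # Non-blocking: rank 0 does not wait for the workers to receive the
    # signal before starting the dtype cast and the send below.
    dist.broadcast(runner.forward_signal, src=0, async_op=True)
    t1 = time.perf_counter()
    hidden_states = hidden_states.to(torch.float16)
    dist.send(hidden_states, dst=1)                      # feed first decode stage
    dist.recv(hidden_states, src=runner.world_size - 1)  # collect last stage's output
    t2 = time.perf_counter()
    # t1 - t0: broadcast launch overhead; t2 - t1: pipeline round trip.
    return hidden_states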