Fix pp logic (#11175)

* only send no none batch and rank1-n sending first

* always send first
This commit is contained in:
Wang, Jian4 2024-05-30 16:40:59 +08:00 committed by GitHub
parent 4127b99ed6
commit c0f1be6aea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -413,6 +413,10 @@ class ModelRunner:
cur_batch = None cur_batch = None
if self.rank == 0: if self.rank == 0:
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
if self.on_going_batches[0] is not None: if self.on_going_batches[0] is not None:
cur_batch = self.on_going_batches[0] cur_batch = self.on_going_batches[0]
cur_input = None cur_input = None
@ -464,22 +468,20 @@ class ModelRunner:
if (cur_batch is not None) and cur_batch.stopped: if (cur_batch is not None) and cur_batch.stopped:
cur_batch = None cur_batch = None
if self.send_buff is not None: if cur_batch is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
dist.broadcast_object_list([cur_batch], src=0) dist.broadcast_object_list([cur_batch], src=0)
else: else:
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
batch_list = [None] batch_list = [None]
dist.broadcast_object_list(batch_list, src=0) dist.broadcast_object_list(batch_list, src=0)
cur_batch = batch_list[0] cur_batch = batch_list[0]
cur_input = None cur_input = None
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
if cur_batch is not None: if cur_batch is not None:
if cur_batch.stopped: if cur_batch.stopped:
self.clear_batch(cur_batch.batch_id) self.clear_batch(cur_batch.batch_id)