Fix pp logic (#11175)
* only send non-None batches, and have ranks 1-n send first
* always send first
This commit is contained in:
parent
4127b99ed6
commit
c0f1be6aea
1 changed file with 9 additions and 7 deletions
|
|
@ -413,6 +413,10 @@ class ModelRunner:
|
||||||
cur_batch = None
|
cur_batch = None
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
|
if self.send_buff is not None:
|
||||||
|
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
|
||||||
|
dist.send(self.send_buff, dst=self.next_rank)
|
||||||
|
|
||||||
if self.on_going_batches[0] is not None:
|
if self.on_going_batches[0] is not None:
|
||||||
cur_batch = self.on_going_batches[0]
|
cur_batch = self.on_going_batches[0]
|
||||||
cur_input = None
|
cur_input = None
|
||||||
|
|
@ -464,22 +468,20 @@ class ModelRunner:
|
||||||
if (cur_batch is not None) and cur_batch.stopped:
|
if (cur_batch is not None) and cur_batch.stopped:
|
||||||
cur_batch = None
|
cur_batch = None
|
||||||
|
|
||||||
if self.send_buff is not None:
|
if cur_batch is not None:
|
||||||
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
|
|
||||||
dist.send(self.send_buff, dst=self.next_rank)
|
|
||||||
dist.broadcast_object_list([cur_batch], src=0)
|
dist.broadcast_object_list([cur_batch], src=0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if self.send_buff is not None:
|
||||||
|
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
|
||||||
|
dist.send(self.send_buff, dst=self.next_rank)
|
||||||
|
|
||||||
batch_list = [None]
|
batch_list = [None]
|
||||||
dist.broadcast_object_list(batch_list, src=0)
|
dist.broadcast_object_list(batch_list, src=0)
|
||||||
|
|
||||||
cur_batch = batch_list[0]
|
cur_batch = batch_list[0]
|
||||||
cur_input = None
|
cur_input = None
|
||||||
|
|
||||||
if self.send_buff is not None:
|
|
||||||
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
|
|
||||||
dist.send(self.send_buff, dst=self.next_rank)
|
|
||||||
|
|
||||||
if cur_batch is not None:
|
if cur_batch is not None:
|
||||||
if cur_batch.stopped:
|
if cur_batch.stopped:
|
||||||
self.clear_batch(cur_batch.batch_id)
|
self.clear_batch(cur_batch.batch_id)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue