
Commit e63c2b2

H-Huang authored and pytorchmergebot committed
[PP] Initialize P2P communicators on first step (#160210)
We were hitting hangs in multi-node settings; initializing the NCCL communicators needed for batched P2P ops ahead of time fixes this. The change adds extra communication, since a dummy tensor is exchanged with the next and previous stage ranks, but that cost is paid only on the first step, so it is negligible.

Debug history: https://docs.google.com/document/d/1EKVJYmW2hj_VsvDvnSggXhZzJyvMu9dA0iDJWOZAtjY/edit?tab=t.0

Pull Request resolved: #160210
Approved by: https://github.com/wconstab
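The idea in isolation: before the first real step, each rank posts a batched send/recv of a one-element dummy tensor with its pipeline neighbors, which forces NCCL to create the point-to-point communicators up front. Below is a minimal standalone sketch of that warm-up; it is not the PR's code, and the linear rank-to-stage layout, tensor shapes, and helper name are assumptions made for illustration.

import torch
import torch.distributed as dist


def warm_up_p2p_communicators() -> None:
    """Exchange a dummy tensor with the previous/next rank so NCCL builds the
    P2P communicators before the first real batched P2P call of the schedule."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    send_buf = torch.tensor([float(rank)], device=device)
    recv_prev = torch.zeros(1, device=device)
    recv_next = torch.zeros(1, device=device)

    ops: list[dist.P2POp] = []
    if rank > 0:  # not the first stage: exchange with the previous rank
        ops.append(dist.P2POp(dist.irecv, recv_prev, rank - 1))
        ops.append(dist.P2POp(dist.isend, send_buf, rank - 1))
    if rank < world_size - 1:  # not the last stage: exchange with the next rank
        ops.append(dist.P2POp(dist.isend, send_buf, rank + 1))
        ops.append(dist.P2POp(dist.irecv, recv_next, rank + 1))

    # One coalesced batch covers every neighbor, mirroring how the schedule
    # batches its real sends/recvs during execution.
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()


if __name__ == "__main__":
    # Launch with torchrun so RANK / WORLD_SIZE / MASTER_ADDR are set.
    dist.init_process_group(backend="nccl")
    warm_up_p2p_communicators()  # communicator-setup cost paid once, at step 0
    dist.destroy_process_group()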
1 parent 3626ba7 commit e63c2b2

2 files changed: +69 −0 lines changed

torch/distributed/pipelining/schedules.py
Lines changed: 15 additions & 0 deletions

@@ -554,6 +554,13 @@ def __init__(
         )
 
     def _initialize_stage(self, args, kwargs):
+        # Prepare the communication needed for the pipeline schedule execution
+        # This is needed because during execution we always perform a series of batch P2P ops
+        # The first call of the batched P2P needs to involve the global group
+        all_ops: list[dist.P2POp] = []
+        all_ops.extend(self._stage._get_init_p2p_neighbors_ops())
+        _wait_batch_p2p(_batch_p2p(all_ops))
+
         self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
         if self._has_backward:
             self._stage._prepare_backward_infra(self._n_microbatches)
@@ -1428,6 +1435,14 @@ def __init__(
         )
 
     def _initialize_stages(self, args: tuple[Any, ...], kwargs):
+        # Prepare the communication needed for the pipeline schedule execution
+        # This is needed because during execution we always perform a series of batch P2P ops
+        # The first call of the batched P2P needs to involve the global group
+        all_ops: list[dist.P2POp] = []
+        for stage in self._stages:
+            all_ops.extend(stage._get_init_p2p_neighbors_ops())
+        _wait_batch_p2p(_batch_p2p(all_ops))
+
         # may be 'none' value (if this stage sends its output shapes to the next stage via P2P)
         # or real value (if this stage and next stage are on the same device)
         next_stage_args: tuple[Any, ...] = tuple()
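The added code relies on the schedule-internal helpers _batch_p2p and _wait_batch_p2p. As a rough stand-in, assuming they essentially wrap the public dist.batch_isend_irecv API and wait on the returned work handles (an assumption for illustration, with hypothetical names, not a copy of the PyTorch internals), they behave roughly like:

import torch.distributed as dist


def batch_p2p(ops: list[dist.P2POp]):
    # Post every send/recv in one coalesced NCCL group; returns async work handles.
    return dist.batch_isend_irecv(ops) if ops else []


def wait_batch_p2p(works) -> None:
    # Block until every op in the batch has finished.
    for work in works:
        work.wait()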

torch/distributed/pipelining/stage.py
Lines changed: 54 additions & 0 deletions

@@ -935,6 +935,60 @@ def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]):
             f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs
         )
 
+    def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]:
+        """
+        Get the operations to initialize the p2p communicators between previous and next stages.
+        This is done by creating a dummy tensor and sending it to the next stage and receiving
+        from the previous stage.
+        """
+        ops: list[dist.P2POp] = []
+        next_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index + 1)
+        prev_stage_peer_rank = self.stage_index_to_group_rank.get(self.stage_index - 1)
+
+        recv_tensor = torch.zeros(1, device=self.device)
+        send_tensor = torch.tensor(self.stage_index, device=self.device)
+        # forward
+        if not self.is_first:
+            ops.append(
+                dist.P2POp(
+                    dist.irecv,
+                    recv_tensor,
+                    group_peer=prev_stage_peer_rank,
+                    group=self.group,
+                )
+            )
+        if not self.is_last:
+            ops.append(
+                dist.P2POp(
+                    dist.isend,
+                    send_tensor,
+                    group_peer=next_stage_peer_rank,
+                    group=self.group,
+                )
+            )
+
+        # backward
+        if not self.is_first:
+            ops.append(
+                dist.P2POp(
+                    dist.isend,
+                    send_tensor,
+                    group_peer=prev_stage_peer_rank,
+                    group=self.group,
+                )
+            )
+        if not self.is_last:
+            ops.append(
+                dist.P2POp(
+                    dist.irecv,
+                    recv_tensor,
+                    group_peer=next_stage_peer_rank,
+                    group=self.group,
+                )
+            )
+
+        return ops
+
 
 class _PipelineStage(_PipelineStageBase):
     def __init__(
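For reference, stage_index_to_group_rank maps a stage index to a rank in the pipeline group, and .get() is used so that out-of-range neighbor indices simply come back as None; the is_first / is_last guards keep those None peers from ever becoming P2P ops. A tiny illustration with a hypothetical round-robin layout (8 stages over 4 ranks, an assumed example rather than the layout of any particular schedule class):

# Hypothetical stage -> rank mapping for an interleaved setup: 8 stages on 4 ranks.
num_stages, group_size = 8, 4
stage_index_to_group_rank = {i: i % group_size for i in range(num_stages)}

stage_index = 0  # the first stage
prev_peer = stage_index_to_group_rank.get(stage_index - 1)  # None: no previous stage
next_peer = stage_index_to_group_rank.get(stage_index + 1)  # 1

# .get() quietly returns None at the pipeline boundaries; the is_first / is_last
# checks in _get_init_p2p_neighbors_ops ensure those None peers are never used.
assert prev_peer is None and next_peer == 1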
