# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import csv
import itertools
import logging
import re
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from enum import Enum
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
import torch
import torch.distributed as dist
from torch._dynamo import OptimizedModule
from torch.distributed.fsdp import FSDPModule, UnshardHandle
from torch.nn.modules.loss import _Loss
from torch.profiler import record_function
from ._utils import generate_stage_to_rank_mapping
from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .stage import _PipelineStageBase
if TYPE_CHECKING:
from torch.distributed import Work
__all__ = [
"get_schedule_class",
"PipelineScheduleSingle",
"PipelineScheduleMulti",
"Schedule1F1B",
"ScheduleGPipe",
"ScheduleInterleaved1F1B",
"ScheduleLoopedBFS",
"ScheduleInterleavedZeroBubble",
"ScheduleZBVZeroBubble",
]
logger = logging.getLogger(__name__)
class _ComputationType(Enum):
# TODO(whc) rename to _ActType?
FORWARD = 1
BACKWARD_INPUT = 2
BACKWARD_WEIGHT = 3
UNSHARD = 4
RESHARD = 5
SEND_F = 6
RECV_F = 7
SEND_B = 8
RECV_B = 9
FULL_BACKWARD = 10
def __str__(self):
str_map = {
_ComputationType.FORWARD: "F",
_ComputationType.BACKWARD_INPUT: "I",
_ComputationType.BACKWARD_WEIGHT: "W",
_ComputationType.UNSHARD: "UNSHARD",
_ComputationType.RESHARD: "RESHARD",
_ComputationType.SEND_F: "SEND_F",
_ComputationType.RECV_F: "RECV_F",
_ComputationType.SEND_B: "SEND_B",
_ComputationType.RECV_B: "RECV_B",
_ComputationType.FULL_BACKWARD: "B",
}
return str_map[self]
@staticmethod
def from_str(action):
if action == "F":
return _ComputationType.FORWARD
elif action == "I":
return _ComputationType.BACKWARD_INPUT
elif action == "W":
return _ComputationType.BACKWARD_WEIGHT
elif action == "UNSHARD":
return _ComputationType.UNSHARD
elif action == "RESHARD":
return _ComputationType.RESHARD
elif action == "SEND_F":
return _ComputationType.SEND_F
elif action == "RECV_F":
return _ComputationType.RECV_F
elif action == "SEND_B":
return _ComputationType.SEND_B
elif action == "RECV_B":
return _ComputationType.RECV_B
elif action == "B":
return _ComputationType.FULL_BACKWARD
else:
raise RuntimeError(f"Invalid computation type {action}")
FORWARD = _ComputationType.FORWARD
BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT
BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT
UNSHARD = _ComputationType.UNSHARD
RESHARD = _ComputationType.RESHARD
SEND_F = _ComputationType.SEND_F
RECV_F = _ComputationType.RECV_F
SEND_B = _ComputationType.SEND_B
RECV_B = _ComputationType.RECV_B
FULL_BACKWARD = _ComputationType.FULL_BACKWARD
# Convenience shorthand for compute actions only since they are used in 'simple schedule format'
F = FORWARD
I = BACKWARD_INPUT
W = BACKWARD_WEIGHT
B = FULL_BACKWARD
# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
_action_regex = re.compile(
r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
)
class _Action(NamedTuple):
stage_index: int
computation_type: _ComputationType
microbatch_index: Optional[int] = None
def __repr__(self):
repr = str(self.stage_index)
repr += str(self.computation_type)
if self.microbatch_index is not None:
repr += str(self.microbatch_index)
return repr
@staticmethod
def from_str(action_string: str):
"""
Reverse of __repr__
String should be formatted as [stage][action type][(microbatch)]
e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
"""
action_string = action_string.strip()
if match := _action_regex.match(action_string):
stage_index, computation_type, microbatch_index = match.groups()
return _Action(
int(stage_index),
_ComputationType.from_str(computation_type),
int(microbatch_index) if len(microbatch_index) else None,
)
elif action_string == "":
return None
raise RuntimeError(
f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
)
def _format_pipeline_order(
pipeline_order: dict[int, list[Optional[_Action]]],
error_step_number: Optional[int] = None,
) -> str:
"""
Formats the pipeline order in a timestep (row) x rank (column) grid of actions
and returns the formatted string.
    If `error_step_number` is passed in, an additional label is added to mark the step
    on which the error occurred.
"""
# don't mutate the original
pipeline_order = copy.deepcopy(pipeline_order)
# Replace None with ""
for rank in pipeline_order:
for i in range(len(pipeline_order[rank])):
if pipeline_order[rank][i] is None:
# TODO make a real 'None action' that prints as empty string and make mypy happy
pipeline_order[rank][i] = "" # type: ignore[call-overload]
# Calculate the maximum number of steps across all ranks
num_steps = max(len(actions) for actions in pipeline_order.values())
step_labels = [
"Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
]
# Sorting the dictionary by keys and retrieving values in that order
rank_actions = [
pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
]
# Transpose the list of lists (rows to columns)
transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
# Generate column labels for ranks
num_ranks = len(pipeline_order)
rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
# Calculate the maximum length of each column, considering labels
max_lengths = [
max(len(str(item)) if item is not None else 0 for item in col)
for col in zip(step_labels, *transposed_actions)
]
# Format the header row with rank labels
header_row = " " * (len(step_labels[0]) + 2) + " ".join(
f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
)
# Format each row with its corresponding label
formatted_rows = [
f"{label}: "
+ " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
+ (
" <-- ERROR HERE"
if error_step_number is not None
and int(label.split()[1]) == error_step_number
else ""
)
for label, row in zip(step_labels, transposed_actions)
]
# Join the rows into a single string
formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
return formatted_table
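# Illustrative sketch (comments only; hypothetical 2-rank, 2-step order): timesteps
# form the rows and ranks form the columns, e.g.
#
#   order = {0: [_Action(0, F, 0), _Action(0, F, 1)],
#            1: [None, _Action(1, F, 0)]}
#   print(_format_pipeline_order(order))
#   # prints roughly:
#   #         Rank 0 Rank 1
#   # Step 0: 0F0
#   # Step 1: 0F1    1F0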
class _PipelineSchedule(ABC):
def __init__(
self,
n_microbatches: int,
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
# From arguments
self._n_microbatches = n_microbatches
self._loss_fn = loss_fn
# See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti`
self.scale_grads = scale_grads
# Chunking specification for positional inputs. (default: `None`)
self._args_chunk_spec = args_chunk_spec
# Chunking specification for keyword inputs. (default: `None`)
self._kwargs_chunk_spec = kwargs_chunk_spec
        self._output_merge_spec = output_merge_spec
        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
        # They are used to convert batch to microbatches in `step(x)`. See
        # `TensorChunkSpec` for helper methods for creating them.
# Derived
self._has_backward = self._loss_fn is not None
# Holds the losses for each microbatch.
self._internal_losses: list[torch.Tensor] = []
logger.info("Using %s", self.__class__.__name__)
def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
if stage.is_last and self._has_backward:
loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index]
self._internal_losses.append(loss)
def _maybe_get_loss(self, stage, mb_index):
valid_index = 0 <= mb_index < len(self._internal_losses)
if stage.is_last and self._has_backward and valid_index:
return self._internal_losses[mb_index]
elif len(self._internal_losses) != 0 and not valid_index:
raise RuntimeError(
f"Loss for microbatch {mb_index} is not available. "
f"Available losses for microbatches: {self._internal_losses}"
)
else:
return None
def _update_losses(self, stages, losses):
"""
Update the losses to those in the internal state
"""
        # if stages is not a list, turn it into a list
if not isinstance(stages, list):
stages = [stages]
contains_last_stage = any(stage.is_last for stage in stages)
# Return losses if there is a container passed in
if contains_last_stage and losses is not None:
if len(self._internal_losses) != self._n_microbatches:
raise RuntimeError(
f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
)
# Clean external container first
losses.clear()
# Copy internal losses to external container
losses.extend(self._internal_losses)
self._internal_losses.clear()
@abstractmethod
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
Will go through all the microbatches according to the schedule
implementation.
Args:
microbatches: list of microbatch args.
"""
raise NotImplementedError
@abstractmethod
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
microbatches according to the schedule implementation.
args: positional arguments to the model (as in non-pipeline case).
kwargs: keyword arguments to the model (as in non-pipeline case).
target: target for the loss function.
losses: a list to store the losses for each microbatch.
"""
raise NotImplementedError
def _check_inputs(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Pre-process/check inputs
"""
def check_type_and_len(mbs, name: str):
if not isinstance(mbs, list):
raise TypeError(f"{name} must be a list but got a {type(mbs)}")
if len(mbs) != self._n_microbatches:
raise ValueError(
f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
)
if arg_mbs is not None:
check_type_and_len(arg_mbs, "arg_mbs")
else:
arg_mbs = [()] * self._n_microbatches
if kwarg_mbs is not None:
check_type_and_len(kwarg_mbs, "kwarg_mbs")
else:
kwarg_mbs = [{}] * self._n_microbatches
if target_mbs is not None:
check_type_and_len(target_mbs, "target_mbs")
if losses is not None:
if not isinstance(losses, list):
raise TypeError(f"losses must be a list but got a {type(losses)}")
return arg_mbs, kwarg_mbs
def _compute_loss(self, output, target):
return self._loss_fn(output, target) # type: ignore[misc]
def _split_inputs(
self,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
):
"""
Splits a full-batch input into chunks (i.e. microbatches) and returns
the chunks
"""
if args or kwargs:
args_split, kwargs_split = split_args_kwargs_into_chunks(
args,
kwargs,
self._n_microbatches,
self._args_chunk_spec,
self._kwargs_chunk_spec,
)
return args_split, kwargs_split
else:
# Empty inputs (e.g. when called on middle stages)
# Return a list of empty tuples/dicts with matching length as chunks
return [()] * self._n_microbatches, [{}] * self._n_microbatches
def _merge_outputs(self, output_chunks: list[Any]) -> Any:
"""
Merge output chunks back to a batch state.
If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
"""
return merge_chunks(
output_chunks,
self._output_merge_spec,
)
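# Illustrative sketch (comments only, assuming the default chunk specs): a whole
# batch is split along dim 0 into `n_microbatches` chunks and merged back the
# same way, e.g.
#
#   x = torch.randn(8, 16)
#   args_split, kwargs_split = split_args_kwargs_into_chunks((x,), None, 4, None, None)
#   # len(args_split) == 4 and each args_split[i][0] has shape (2, 16)
#   merged = merge_chunks([chunk[0] for chunk in args_split], None)
#   # merged has shape (8, 16) again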
def _batch_p2p(p2p_ops: list[dist.P2POp], desc: Optional[str] = None):
"""
    Simple wrapper over batch_isend_irecv from torch.distributed, which just adds descriptive debug logging on top.
"""
if len(p2p_ops) == 0:
return None
desc_str = f"{desc}, " if desc else ""
logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
return dist.batch_isend_irecv(p2p_ops).pop()
def _sorted_batch_p2p(
p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> dict[int, dist.Work]:
"""
Sorts the list of P2P ops by the peer rank, and then calls
batch_isend_irecv. Return a dictionary of works by peer rank. This function
helps us avoid hangs in case of skip connections.
"""
# Arrange p2p_ops by peer rank:
# int is the peer rank;
# List is the list of ops towards the peer
ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
work_by_peer: dict[int, dist.Work] = {}
if len(p2p_ops) == 0:
return work_by_peer
# Classify the ops by peer rank
for op in p2p_ops:
ops_by_peer[op.peer].append(op)
# Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
for peer, ops in sorted(ops_by_peer.items()):
work_by_peer[peer] = _batch_p2p(ops, desc=desc)
return work_by_peer
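# Illustrative sketch (comments only; hypothetical tensors t0/t1): ops are grouped
# by peer rank and batched in ascending peer order, so both sides of a skip
# connection launch their batched p2p calls in the same order, e.g.
#
#   ops = [dist.P2POp(dist.isend, t0, peer=2), dist.P2POp(dist.irecv, t1, peer=1)]
#   for peer, work in _sorted_batch_p2p(ops, desc="fwd").items():
#       work.wait()   # peer 1 first, then peer 2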
class PipelineScheduleSingle(_PipelineSchedule):
"""
Base class for single-stage schedules.
Implements the `step` method.
Derived classes should implement `_step_microbatches`.
Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
or sum losses (scale_grads=False).
"""
def __init__(
self,
stage: _PipelineStageBase,
n_microbatches: int,
loss_fn: Optional[Callable] = None,
args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
scale_grads: bool = True,
):
# Init parent
super().__init__(
n_microbatches=n_microbatches,
loss_fn=loss_fn,
args_chunk_spec=args_chunk_spec,
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
scale_grads=scale_grads,
)
# Self attributes
self._stage = stage
self._num_stages = stage.num_stages
# Set the same has_backward flag for stage object
self._stage.has_backward = self._has_backward
self._stage_initialized = False
if n_microbatches < self._num_stages:
            raise ValueError(
                f"Number of microbatches ({n_microbatches}) must be greater than "
                f"or equal to the number of stages ({self._num_stages})."
)
def _initialize_stage(self, args, kwargs):
self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
if self._has_backward:
self._stage._prepare_backward_infra(self._n_microbatches)
self._stage_initialized = True
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
microbatches according to the schedule implementation.
args: positional arguments to the model (as in non-pipeline case).
kwargs: keyword arguments to the model (as in non-pipeline case).
target: target for the loss function.
losses: a list to store the losses for each microbatch.
"""
# Clean per iteration
self._stage.clear_runtime_states()
# Split inputs into microbatches
args_split, kwargs_split = self._split_inputs(args, kwargs)
# Split target into microbatches
if target is not None:
targets_split = list(torch.tensor_split(target, self._n_microbatches))
else:
targets_split = None
# Run microbatches
self._step_microbatches(args_split, kwargs_split, targets_split, losses)
# Return merged results per original format
if self._stage.is_last:
return self._merge_outputs(self._stage.output_chunks)
else:
return None
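# Minimal usage sketch (comments only; hypothetical model/rank variables, and
# assuming a PipelineStage from torch.distributed.pipelining plus an initialized
# process group): single-stage schedules are driven by `step`, which chunks the
# whole batch into `n_microbatches` and dispatches to `_step_microbatches`.
#
#   stage = PipelineStage(stage_module, stage_index=rank, num_stages=world_size, device=device)
#   schedule = ScheduleGPipe(stage, n_microbatches=8, loss_fn=torch.nn.functional.cross_entropy)
#   losses: list[torch.Tensor] = []
#   if stage.is_first:
#       schedule.step(x)
#   else:
#       out = schedule.step(target=y, losses=losses)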
class _ScheduleForwardOnly(PipelineScheduleSingle):
"""
The forward-only schedule.
Will go through all the microbatches and perform only the forward pass
"""
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule
"""
if target_mbs is not None or losses is not None:
raise RuntimeError(
"Forward-only schedule does not support loss computation"
)
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stage_initialized:
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Delay send waits
fwd_sends_to_wait: list[dist.Work] = []
# Run microbatches
for i in range(self._n_microbatches):
with record_function(f"Forward {i}"):
ops = self._stage.get_fwd_recv_ops(i)
works = _sorted_batch_p2p(ops, desc="fwd_recv")
for work in works.values():
work.wait()
self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
ops = self._stage.get_fwd_send_ops(i)
works = _sorted_batch_p2p(ops, desc="fwd_send")
fwd_sends_to_wait.extend(works.values())
logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)
# Wait for all forward sends to finish
# This should not have performance impact because by the time the first
# backward arrives all the forward sends should have been finished.
for work in fwd_sends_to_wait:
work.wait()
class ScheduleGPipe(PipelineScheduleSingle):
"""
The GPipe schedule.
Will go through all the microbatches in a fill-drain manner.
"""
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
Will go through all the microbatches according to the GPipe schedule.
Args:
microbatches: list of microbatch args.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stage_initialized:
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Delay send waits
fwd_sends_to_wait: list[dist.Work] = []
# Run microbatches
for i in range(self._n_microbatches):
with record_function(f"Forward {i}"):
ops = self._stage.get_fwd_recv_ops(i)
works = _sorted_batch_p2p(ops, desc="fwd_recv")
for work in works.values():
work.wait()
output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
ops = self._stage.get_fwd_send_ops(i)
works = _sorted_batch_p2p(ops, desc="fwd_send")
fwd_sends_to_wait.extend(works.values())
logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)
self._maybe_compute_loss(self._stage, output, target_mbs, i)
# Wait for all forward sends to finish
# This should not have performance impact because by the time the first
# backward arrives all the forward sends should have been finished.
for work in fwd_sends_to_wait:
work.wait()
# No loss function, no need to run backward
if not self._has_backward:
return
# Run backward
# Delay send waits
bwd_sends_to_wait: list[dist.Work] = []
for i in range(self._n_microbatches):
with record_function(f"Backward {i}"):
ops = self._stage.get_bwd_recv_ops(i)
works = _sorted_batch_p2p(ops, desc="bwd_recv")
for work in works.values():
work.wait()
loss = self._maybe_get_loss(self._stage, i)
self._stage.backward_one_chunk(
i,
loss=loss,
last_backward=i == self._n_microbatches - 1,
)
ops = self._stage.get_bwd_send_ops(i)
works = _sorted_batch_p2p(ops, desc="bwd_send")
bwd_sends_to_wait.extend(works.values())
logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)
self._stage.scale_grads(
grad_scale_factor=self._n_microbatches if self.scale_grads else 1
)
# Return losses if there is a container passed in
self._update_losses(self._stage, losses)
# Wait for all backward sends to finish
for work in bwd_sends_to_wait:
work.wait()
class Schedule1F1B(PipelineScheduleSingle):
"""
The 1F1B schedule.
Will perform one forward and one backward on the microbatches in steady state.
"""
def _step_microbatches(
self,
arg_mbs: Optional[list] = None,
kwarg_mbs: Optional[list] = None,
target_mbs: Optional[list] = None,
losses: Optional[list] = None,
):
"""
Run one iteration of the pipeline schedule with list of microbatches.
Will go through all the microbatches according to the 1F1B schedule.
Args:
microbatches: list of microbatch args.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
if not self._stage_initialized:
self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
# Last stage has 1 warmup, second-to-last 2 warmups, ...
# first stage `num_stages` warmups
warmup_chunks = min(
self._n_microbatches,
self._num_stages - self._stage.stage_index,
)
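        # e.g. with 4 stages and 8 microbatches: stage 0 warms up min(8, 4) = 4 chunks,
        # stage 1 warms up 3, ..., and the last stage (index 3) warms up 1.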
# Chunk counters
fwd_mb_index = 0
bwd_mb_index = 0
# Warmup phase
send_work = None
fwd_sends = []
for _ in range(warmup_chunks):
# Receive activations
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
if recv_work := _batch_p2p(fwd_recvs, desc="fwd_recv"):
recv_work.wait()
# Compute
output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
) # type: ignore[index]
            # Clear the previous chunk's forward sends (hopefully they have
            # finished by now; otherwise we are heavily communication bound, in
            # which case eagerly computing the next chunk would not bring much
            # benefit either)
if send_work:
send_work.wait()
# Send activations
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
if fwd_mb_index != warmup_chunks - 1:
# Safe to fire
send_work = _batch_p2p(fwd_sends, desc="fwd_send")
# otherwise:
            #   The last forward send is left to be fused with the first 1B of the 1B1F phase below
# Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
fwd_mb_index += 1
# Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.
# 1B1F phase
while True: # Don't worry, we have a break inside
# We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
# Now, we need to fire the fwd_sends and bwd_recvs together
if fuse_work := _batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"):
fuse_work.wait()
# Backward one chunk
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(
bwd_mb_index,
loss=loss,
last_backward=bwd_mb_index == self._n_microbatches - 1,
)
# Get the bwd send ops, but don't fire, to be fused with the 1F below
bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
bwd_mb_index += 1
if fwd_mb_index == self._n_microbatches:
# We are done with 1B1F, so break with some left-over bwd_sends
break
# We prepare 1F of the `1B1F`
fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
# Fuse it with bwd_sends above
if fuse_work := _batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"):
fuse_work.wait()
# Now do the fwd
output = self._stage.forward_one_chunk(
fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
) # type: ignore[index]
# Compute loss
self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
# Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
fwd_mb_index += 1
        # Remember the bwd_sends left over after the break? Now it is time to fire them
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
# Cooldown
while bwd_mb_index < self._n_microbatches:
# prepare bwd recv ops
bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
if recv_work := _batch_p2p(bwd_recvs, desc="bwd_recv"):
recv_work.wait()
# Backward one chunk
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(
bwd_mb_index,
loss=loss,
last_backward=bwd_mb_index == self._n_microbatches - 1,
)
            # Clear the previous chunk's backward sends (hopefully they have finished by now)
if send_work:
send_work.wait()
# Get the bwd send ops, fire it
bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
send_work = _batch_p2p(bwd_sends, desc="bwd_send")
bwd_mb_index += 1
self._stage.scale_grads(
grad_scale_factor=self._n_microbatches if self.scale_grads else 1
)
# Wait for the last backward send to finish
if send_work:
send_work.wait()
# Return losses if there is a container passed in
self._update_losses(self._stage, losses)
def _add_unshard_reshard(
compute_actions: list[Optional[_Action]],
max_active_stages: int = 3,
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP.
UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
    RESHARD does the opposite, releasing memory (but doing no communication).
    We abandon the "timestep lock" during lowering.
    max_active_stages controls how many prefetches we allow. It should be measured in microbatches and be tunable,
    but in practice 3 stages is probably what we want
    (to account for having one forward and one backward active, plus something else prefetching).
"""
def next_stage_indices(
count: int, next_actions: list[Optional[_Action]]
) -> list[int]:
"""Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
seen: set[int] = set()
ret: list[int] = []
for a in next_actions:
if a is not None and a.stage_index not in seen:
seen.add(a.stage_index)
ret.append(a.stage_index)
if len(ret) == count:
break
return ret
active_stages: set[int] = set()
fsdp_aware_actions: list[_Action] = []
def _unshard(stage_index: int):
active_stages.add(stage_index)
fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))
def _reshard(stage_index: int):
active_stages.remove(stage_index)
fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))
for i, action in enumerate(compute_actions):
if action is None:
continue
# We prefetch the next N stages we'll see, dropping existing stages to make room
next_n = next_stage_indices(max_active_stages, compute_actions[i:])
# Fetch needs to be ordered correctly, so don't use a set
fetch = list(filter(lambda s: s not in active_stages, next_n))
# Unclear what the best policy is for eviction, but we can maintain order so we do
evict = list(filter(lambda s: s not in next_n, active_stages))
# logger.debug(
# "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
# i,
# active_stages,
# fetch,
# evict,
# )
for stage in evict:
_reshard(stage)
for stage in fetch:
_unshard(stage)
fsdp_aware_actions.append(action)
return fsdp_aware_actions
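# Illustrative sketch (comments only): with max_active_stages=2, a stage is
# unsharded before its first upcoming compute and resharded once it is no longer
# among the next stages that will run, e.g.
#
#   _add_unshard_reshard([_Action(0, F, 0), _Action(1, F, 0)], max_active_stages=2)
#   # -> [0UNSHARD, 1UNSHARD, 0F0, 0RESHARD, 1F0]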
def _merge_bw(
compute_actions: list[Optional[_Action]],
) -> list[_Action]:
"""Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
(note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
in some cases.
"""
merged_actions = []
while compute_actions:
action = compute_actions.pop(0)
if action is None:
continue
while len(compute_actions) and (next_action := compute_actions[0]) is None:
# remove any None actions between 'action' and 'next_action'
compute_actions.pop(0)
if (
action.computation_type == BACKWARD_INPUT
and next_action is not None
and next_action.computation_type == BACKWARD_WEIGHT
and action.stage_index == next_action.stage_index
and action.microbatch_index == next_action.microbatch_index
):
merged_actions.append(
_Action(action.stage_index, FULL_BACKWARD, action.microbatch_index)
)
compute_actions.pop(0)
else:
merged_actions.append(action)
return merged_actions
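# Illustrative sketch (comments only): adjacent I and W actions for the same stage
# and microbatch collapse into one full backward, while unmatched ones pass through, e.g.
#
#   _merge_bw([_Action(0, I, 0), _Action(0, W, 0), _Action(0, I, 1)])
#   # -> [0B0, 0I1]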
def _add_send_recv(
compute_actions: dict[int, list[_Action]],
stage_to_rank: Callable[[int], int],
num_stages: int,
) -> dict[int, list[_Action]]:
comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}
def _has_comms(action: _Action) -> bool:
if action.computation_type == F:
return action.stage_index != num_stages - 1 and stage_to_rank(
action.stage_index + 1
) != stage_to_rank(action.stage_index)
elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
return action.stage_index != 0 and stage_to_rank(
action.stage_index - 1
) != stage_to_rank(action.stage_index)
return False
def _get_comms(action: _Action) -> tuple[_Action, _Action]:
assert _has_comms(action), f"{action} is not a valid comm action"
stage_idx = action.stage_index
ctype = action.computation_type
mb_idx = action.microbatch_index
send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
return send, recv
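    # Illustrative sketch (comments only): a forward on stage 1 whose successor
    # lives on a different rank expands into a send from stage 1 and a matching
    # recv registered for stage 2, e.g.
    #
    #   _get_comms(_Action(1, F, 0))  # -> (1SEND_F0, 2RECV_F0)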
def _ready_to_schedule(
action: Optional[_Action], prev_actions: set[_Action]
) -> bool:
"""We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
This helps ensure a sane (non-hanging) ordering of sends and recvs.
But it also means we might not be able to schedule our next compute action yet.
"""
if action is None:
return True
elif action.computation_type == F and not action.stage_index == 0:
if (
_Action(action.stage_index, RECV_F, action.microbatch_index)
in prev_actions
):
return True
elif (
_Action(action.stage_index - 1, F, action.microbatch_index)
in prev_actions
):
return True
return False
elif (
action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
and not action.stage_index == num_stages - 1
):
if (
_Action(action.stage_index, RECV_B, action.microbatch_index)
in prev_actions
):
return True
elif (
_Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
in prev_actions
):
return True
elif (
_Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
in prev_actions
):
return True
return False
else:
return True
while compute_actions:
progress = False
# go in order of ranks even if dict keys aren't ordered
for rank in sorted(compute_actions):
assert len(compute_actions[rank]) > 0, (
f"{rank=}, {len(compute_actions[rank])=}"