// reducer.cpp (2451 lines, 2220 loc, 95.8 KB)
#include <torch/csrc/distributed/c10d/reducer.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/default_comm_hooks.hpp>
#include <functional>
#include <c10/core/DeviceGuard.h>
#include <c10/core/ScalarType.h>
#include <c10/core/StreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/utils/grad_layout_contract.h>
#include <torch/csrc/autograd/utils/lambda_post_hook.h>
#include <torch/csrc/distributed/c10d/comm.hpp>
#include <torch/csrc/distributed/c10d/logger.hpp>
#include <utility>
namespace c10d {
namespace {
constexpr int kUnsetDivFactor = -1;
// Macro that wraps TORCH_CHECK with DDP logging.
#define REDUCER_CHECK(cond, logger_, ...) \
if (C10_UNLIKELY_OR_CONST(!(cond))) { \
if (!logger_.expired()) { \
logger_.lock()->set_error_and_log(__VA_ARGS__); \
} \
TORCH_CHECK(false, ##__VA_ARGS__); \
}
} // namespace
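// Typed registry mapping a c10::DeviceType to a Timer implementation,
// constructed from a c10::Device. The CPU timer is registered below; other
// device backends register their timers elsewhere, so the Reducer can record
// per-iteration timing on whatever device the parameters live on.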
C10_DEFINE_TYPED_REGISTRY(
TimerRegistry,
c10::DeviceType,
Timer,
std::unique_ptr,
c10::Device)
namespace {
class CpuTimer : public Timer {
public:
explicit CpuTimer(c10::Device /* unused */) {}
std::optional<int64_t> measureDifference(Event start, Event end) override {
int64_t start_time = getTimeRef(start);
int64_t end_time = getTimeRef(end);
// If the end time is not recorded in this iteration,
// avg_time would be computed from an invalid value.
// In some cases, such as DDP running in non-sync mode, the backward
// compute end time cannot be recorded in this iteration, so a valid
// avg_time cannot be calculated.
// In that case, skip the measurement and return nullopt.
if (end_time < start_time) {
return std::nullopt;
}
return end_time - start_time;
}
};
C10_REGISTER_TYPED_CLASS(TimerRegistry, c10::kCPU, CpuTimer)
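// Unwraps an IValue produced by a work/future result into a flat list of
// tensors. Handles PyObject results (from Python hooks) as well as a single
// Tensor or a TensorList.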
std::vector<at::Tensor> extractTensors(const c10::IValue& result) {
if (result.isPyObject()) {
return result.toPyObjectHolder()->extractTensors();
}
TORCH_INTERNAL_ASSERT(
result.isTensor() || result.isTensorList(),
"expected the hook result is either a Tensor or a TensorList found ",
result.tagKind());
if (result.isTensor()) {
return {result.toTensor()};
}
return result.toTensorVector();
}
} // namespace
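// The constructor records the parameters and their bucket assignment,
// initializes the buckets, and installs a post hook on every parameter's
// gradient accumulator so the Reducer is notified (via autograd_hook) when
// a gradient becomes ready during the backward pass.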
Reducer::Reducer(
std::vector<at::Tensor> params,
std::vector<std::vector<size_t>> bucket_indices,
c10::intrusive_ptr<c10d::ProcessGroup> process_group,
std::vector<bool> expect_sparse_gradients,
int64_t bucket_bytes_cap,
bool find_unused_parameters,
bool gradient_as_bucket_view,
std::unordered_map<size_t, std::string> param_names,
int64_t first_bucket_bytes_cap,
bool skip_all_reduce_unused_params)
: params_(std::move(params)),
process_group_(std::move(process_group)),
expect_sparse_gradients_(std::move(expect_sparse_gradients)),
expect_autograd_hooks_(false),
require_finalize_(false),
next_bucket_(0),
has_marked_unused_parameters_(false),
find_unused_parameters_(find_unused_parameters),
gradient_as_bucket_view_(gradient_as_bucket_view),
local_used_map_reduced_(false),
num_iterations_(0),
num_bwd_calls_(0),
first_autograd_hook_called_(false),
num_buckets_ready_(0),
num_buckets_reduced_(0),
has_rebuilt_bucket_(false),
bucket_bytes_cap_(bucket_bytes_cap),
div_factor_(kUnsetDivFactor),
static_graph_(false),
skip_all_reduce_unused_params_(skip_all_reduce_unused_params),
comm_hook_(nullptr),
ddp_debug_level_(debug_level()),
param_names_(std::move(param_names)),
first_bucket_bytes_cap_(first_bucket_bytes_cap) {
C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer");
TORCH_INTERNAL_ASSERT(!params_.empty(), "Expected at least one parameter.");
if (ddp_debug_level_ != c10d::DebugLevel::Off) {
LOG(INFO) << "Reducer initialized with bucket_bytes_cap: "
<< bucket_bytes_cap_
<< " first_bucket_bytes_cap: " << first_bucket_bytes_cap;
}
// Check whether the module is multi_device_module
{
std::set<int> unique_devices;
for (const auto& v : params_) {
auto device_idx = int(v.device().index());
if (unique_devices.find(device_idx) == unique_devices.end()) {
unique_devices.insert(device_idx);
if (unique_devices.size() > 1) {
is_multi_device_module_ = true;
break;
}
}
}
}
// For CUDA, record events only for single device module.
c10::Device device = params_[0].device();
if (!(device.is_cuda() && is_multi_device_module_)) {
timer_ = TimerRegistry()->Create(device.type(), device);
}
// If `expect_sparse_gradients` is not specified, initialize it such that
// we do not expect sparse gradients for any parameter.
if (expect_sparse_gradients_.empty()) {
expect_sparse_gradients_ = std::vector<bool>(params_.size(), false);
}
TORCH_INTERNAL_ASSERT(expect_sparse_gradients_.size() == params_.size());
// Initialize variable bucketing.
// This can be reinitialized later after capturing runtime information.
{
std::lock_guard<std::mutex> lock(mutex_);
initialize_buckets(std::move(bucket_indices));
}
// All variables are expected to have their `grad_fn` set to the gradient
// accumulation function (since they are leaves in the autograd graph).
// We store pointers to these functions such that we can check if they are
// used in an autograd pass. If they are not, we know their grad tensors
// can be marked as ready for reduction.
{
const auto variable_count = params_.size();
grad_accumulators_.resize(variable_count);
for (const auto variable_index : c10::irange(variable_count)) {
auto& variable = params_[variable_index];
// The gradient accumulator function is lazily initialized once.
// Therefore we can use its presence in the autograd graph as
// evidence that the parameter has participated in an iteration.
auto grad_accumulator = torch::autograd::impl::grad_accumulator(variable);
#ifndef _WIN32
using torch::distributed::autograd::ThreadLocalDistAutogradContext;
#endif
// Hook to execute after the gradient accumulator has executed.
hooks_.emplace_back(
grad_accumulator->add_post_hook(std::make_unique<
torch::autograd::utils::
LambdaPostHook>(
[this, variable_index](
const torch::autograd::variable_list& outputs,
const torch::autograd::variable_list& /* unused */) {
#ifndef _WIN32
this->rpc_context_.set(
ThreadLocalDistAutogradContext::getContextPtr());
#endif
this->autograd_hook(variable_index);
return outputs;
},
[=](torch::autograd::CompiledNodeArgs& args) {
TORCH_INTERNAL_ASSERT(
false,
"Compiled autograd is not compatible with C++ DDP Reducer, please use torch._dynamo.config.optimize_ddp=\"python_reducer\".");
})),
grad_accumulator);
// Map raw function pointer to parameter index.
// This is used later on when the autograd graph is traversed
// to check for parameters for which no gradient is computed, if
// find_unused_parameters=True.
// Note that the mapping of gradient accumulator to variable should be
// one to one as we deduplicate shared parameters before constructing
// Reducer.
if (find_unused_parameters_) {
gradAccToVariableMap_[grad_accumulator.get()] = variable_index;
}
numGradHooksTriggeredMap_[variable_index] = 0;
// The gradient accumulator is stored as weak_ptr in the autograd
// metadata of the variable, so we have to keep it alive here for
// the raw pointer to be valid.
REDUCER_CHECK(
grad_accumulators_[variable_index] == nullptr,
logger_,
c10::str(
"Reducer tried to register duplicate grad accumulator for variable ",
variable_index));
grad_accumulators_[variable_index] = std::move(grad_accumulator);
}
}
// Initialize backward stats vector.
{
const auto variable_count = params_.size();
backward_stats_.resize(variable_count);
}
// See Note [Skip allreducing local_used_map_dev]
if (find_unused_parameters_) {
initialize_local_used_map();
}
}
// Note [Skip allreducing local_used_map_dev]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// If find_unused_parameters_ is set to false, there is no need to allreduce
// local_used_map_dev_, because all parameters will be reduced anyway.
// Therefore, we can avoid allocating memory for local_used_map and
// local_used_map_dev_ if find_unused_parameters_ is false.
// Note [DDP Communication Hook]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// If DDP communication hook is not registered, the reducer reduces the buckets
// by just calling allreduce. If registered, it calls the hook and uses future
// work handle. If registered, reducer also skips dividing grads by world size.
// The reason for this is that the communication hook is expected to completely
// override how we perform communication and the user should have complete
// control over how the grads are handled.
//
// DDP communication hook is an enhancement that provides a hook which can be
// used to override how DDP communicates gradients across ranks, this can be
// used for algorithms like Gradient Compression/GossipGrad. This hook can be
// registered from Python API using `register_comm_hook`. `PythonCommHook`
// enables registering a Python hook and is a subclass of `CommHookInterface`.
// Additionally, there are also some built-in C++ hook implementations that can
// be specified by calling `register_builtin_comm_hook` from Python API.
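//
// Illustrative Python usage (a sketch only; see the torch.distributed
// documentation for the authoritative API):
//
//   from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
//   ddp_model = torch.nn.parallel.DistributedDataParallel(model)
//   ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)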
Reducer::~Reducer() noexcept(false) {
remove_autograd_hooks();
}
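// Small helpers describing which code path the current backward call takes:
// dynamic unused-parameter detection vs. the first / subsequent iterations
// of a static graph.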
bool Reducer::dynamic_graph_find_unused() {
return !static_graph_ && find_unused_parameters_;
}
bool Reducer::static_graph_first_iteration() {
return static_graph_ && num_bwd_calls_ == 1;
}
bool Reducer::static_graph_after_first_iteration() {
return static_graph_ && num_bwd_calls_ > 1;
}
bool Reducer::ddp_graph_static() {
std::lock_guard<std::mutex> lock(mutex_);
return ddp_graph_static_;
}
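// Allocates the host-side map of locally used parameters and its device-side
// counterpart that is allreduced across ranks when
// find_unused_parameters=True.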
void Reducer::initialize_local_used_map() {
const auto variable_count = params_.size();
at::TensorOptions options;
options = options.dtype(at::kInt);
// Deliberately don't pin the memory even if local_used_map_dev_ will
// be cuda. See Note [local_used_map_ -> local_used_map_dev copying]
local_used_map_ = at::zeros({static_cast<long>(variable_count)}, options);
// This tensor needs to be on the same device as the replica params because
// backends such as NCCL may not support CPU tensors, and hence it might not
// work if we always put it on CPU. The dist backend for MTIA doesn't support
// int32 allreduce for now, so it has to be placed on CPU.
options = options.device(
(params_[0].is_mtia()) ? c10::Device(c10::DeviceType::CPU)
: params_[0].device());
local_used_map_dev_ = at::empty({static_cast<long>(variable_count)}, options);
}
void Reducer::check_grad_layout(
const at::Tensor& grad,
const at::Tensor& bucket_view) {
// Ensure that the gradient type matches the bucket type, or mixed precision
// type if we are training with mixed precision.
auto type = mixed_precision_param_dtype_
? *mixed_precision_param_dtype_
: bucket_view.options().dtype().toScalarType();
REDUCER_CHECK(
grad.options().dtype().toScalarType() == type,
logger_,
c10::str(
"Expected ", type, ", got ", grad.options().dtype().toScalarType()));
TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
// AccumulateGrad doesn't HAVE to obey the grad layout contract.
// The penalty for disobedience is reduced performance, not numerical
// death. Warnings here help diagnose poor DDP performance.
if (grad.strides() != bucket_view.strides()) {
TORCH_WARN_ONCE(
"Grad strides do not match bucket view strides. "
"This may indicate grad was not created according to the "
"gradient layout contract, or that the param's strides "
"changed since DDP was constructed. This is not an error, "
"but may impair performance.\n"
"grad.sizes() = ",
grad.sizes(),
", strides() = ",
grad.strides(),
"\n",
"bucket_view.sizes() = ",
bucket_view.sizes(),
", strides() = ",
bucket_view.strides());
}
if (!gradient_as_bucket_view_) {
TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
}
}
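// Called when a dense gradient is ready: copies it into its slot of the
// bucket's flattened gradient tensor (pre-dividing by the division factor
// when no comm hook is registered), or zeroes that slot if the gradient is
// undefined.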
void Reducer::mark_variable_ready_dense(size_t variable_index) {
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& variable = bucket.variables[bucket_index.intra_bucket_index];
auto& bucket_view = bucket.bucket_views_in[bucket_index.intra_bucket_index];
// Copy the contents of the gradient tensor to the corresponding part of the
// bucket's flattened gradient tensor.
// If the gradient is not set, we assume it wasn't computed as part of the
// current backwards pass, and we zero the part of the bucket it would
// otherwise hold.
runGradCallbackForVariable(variable, [&](auto& grad) {
if (grad.defined()) {
this->check_grad_layout(grad, bucket_view);
// When gradient_as_bucket_view_ is false, or even when
// gradient_as_bucket_view_ is true, in rare cases users may set grad to
// be None after every iteration. In these cases, grad and bucket_view
// point to different storages and the grads thus need to be copied into
// bucket_view. If gradient_as_bucket_view_ is set to true, let grad point
// to bucket_view. If grad has already been set as views of buckets in
// previous iterations, no copy is needed.
if (!grad.is_alias_of(bucket_view)) {
if (comm_hook_ == nullptr) {
auto wrapped =
at::native::wrapped_scalar_tensor(double(1.) / div_factor_);
if (!grad.requires_grad()) {
// Divides while copying into the bucket view to save one scan over
// all the input parameters.
RECORD_FUNCTION(
"torch::distributed::reducer::mul_out",
std::vector<c10::IValue>({bucket_view}))
at::mul_out(bucket_view, grad, wrapped);
} else {
// If DDP is running with create_graph=True, gradients require_grad
// themselves in order to compute higher order derivatives. However,
// DDP will not sync up these gradients currently (see
// https://github.com/pytorch/pytorch/issues/63812).
C10_LOG_EVERY_N(WARNING, 1000)
<< "Using DistributedDataParallel with create_graph=True "
<< " is not well-supported. The higher-order gradient will "
<< " not be synchronized across ranks, and backpropagation "
<< " through all_reduce operations will not occur. If you require "
<< " DDP to work with higher-order gradients for your use case, "
<< " please ping https://github.com/pytorch/pytorch/issues/63929";
auto div_result = at::mul(grad, wrapped);
RECORD_FUNCTION(
"torch::distributed::reducer::copy_",
std::vector<c10::IValue>({bucket_view}))
bucket_view.copy_(div_result);
}
} else {
RECORD_FUNCTION(
"torch::distributed::reducer::copy_",
std::vector<c10::IValue>({bucket_view}))
bucket_view.copy_(grad);
}
if (gradient_as_bucket_view_) {
// Let grad point to bucket_view buffer.
grad = bucket_view;
// The grad is modified and need to be written back.
return true;
}
} else {
// If grad and bucket view point to the same storage, no need to copy.
if (comm_hook_ == nullptr) {
bucket_view.div_(div_factor_);
}
}
} else {
// Gradient is undefined. When find_unused_parameters=True, ensure it is
// not marked as locally used, otherwise we will be allreducing zero's
// instead of not touching .grad field of parameter.
if (this->dynamic_graph_find_unused() ||
this->static_graph_first_iteration()) {
REDUCER_CHECK(
local_used_map_[variable_index].item<int>() == 0,
logger_,
"Encountered gradient which is undefined, but still allreduced by "
"DDP reducer. This indicates a bug in DDP implementation, please "
"report a bug with a repro to PyTorch.");
}
bucket_view.zero_();
}
// The grad is not modified and doesn't need to be written back.
return false;
});
}
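// Sparse counterpart of mark_variable_ready_dense: sparse gradients cannot
// be flattened into a shared bucket, so the sparse tensor is assigned to the
// bucket's `gradients` field directly (optionally remapped through
// sparse_metadata_ indices).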
void Reducer::mark_variable_ready_sparse(size_t variable_index) {
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& variable = bucket.variables[bucket_index.intra_bucket_index];
runGradCallbackForVariable(variable, [&](auto& grad) {
REDUCER_CHECK(
grad.defined(), logger_, "Expected sparse gradient to be defined.");
REDUCER_CHECK(
grad.options().layout() == c10::kSparse,
logger_,
"Expected variable to have sparse gradient.");
// Copy the indices of sparse metadata
if (sparse_metadata_) {
grad = grad.coalesce();
REDUCER_CHECK(
!param_names_.empty(), logger_, "No parameter names were found");
std::string& param_name = param_names_[variable_index];
auto iter = sparse_metadata_->find(param_name);
REDUCER_CHECK(
iter != sparse_metadata_->end(),
logger_,
"param: " + param_name + " not found in sparse metadata");
bucket.sparse_tensor_indices =
iter->second.to(at::kLong).unsqueeze(0).to(grad.device());
auto indices = at::searchsorted(
bucket.sparse_tensor_indices.value(), grad.indices(), false, false);
// For indices we are using the ones set by sparse_metadata
grad = at::sparse_coo_tensor(indices, grad.values(), grad.sizes());
}
// Sparse tensors cannot be grouped together with other sparse tensors in a
// single reduction operation like we can for dense tensors. Therefore, the
// `offsets` and `lengths` vectors in the bucket struct are empty, and
// there is no pre-existing accumulation tensor.
// Directly assign the sparse tensor to the `gradients` field.
bucket.gradients = grad;
// If no DDP comm hook is registered, the allreduce only sums up the
// value, and a separate division is required.
if (comm_hook_ == nullptr) {
bucket.gradients.div_(div_factor_);
}
// The grad is modified in place and needs to be written back.
return true;
});
}
std::vector<c10d::GradBucket> Reducer::get_grad_buckets(
bool return_zero_tensors) const {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<c10d::GradBucket> gradBuckets;
gradBuckets.reserve(buckets_.size());
for (const auto i : c10::irange(buckets_.size())) {
auto& bucket = buckets_[i];
auto variables_for_bucket = get_variables_for_bucket(i, bucket);
gradBuckets.emplace_back(
i,
buckets_.size(),
return_zero_tensors ? at::zeros_like(bucket.gradients)
: bucket.gradients,
bucket.offsets,
bucket.lengths,
bucket.sizes_vec,
variables_for_bucket,
std::nullopt);
}
return gradBuckets;
}
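// Stores the async work handle launched during the forward pass;
// set_divide_factor() later waits on it to learn the number of currently
// participating processes when the world size is not static.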
void Reducer::set_forward_pass_work_handle(
c10::intrusive_ptr<c10d::Work> forwardPassWorkHandle,
bool useStaticWorldSize) {
std::lock_guard<std::mutex> lock(mutex_);
forwardPassWorkHandle_.workHandle = std::move(forwardPassWorkHandle);
forwardPassWorkHandle_.useStaticWorldSize = useStaticWorldSize;
}
at::Tensor Reducer::get_local_used_map_on_device() const {
std::lock_guard<std::mutex> lock(mutex_);
return local_used_map_dev_;
}
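// Seeds rebuilt_params_ / rebuilt_param_indices_ with every parameter in
// index order, but only if bucket rebuilding is still pending and nothing
// has been recorded yet.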
void Reducer::push_rebuilt_params_for_all_indices() {
std::lock_guard<std::mutex> lock(mutex_);
if (!should_rebuild_buckets() || !rebuilt_param_indices_.empty()) {
return;
}
const auto variable_count = params_.size();
for (const auto variable_index : c10::irange(variable_count)) {
push_rebuilt_params(variable_index);
}
}
void Reducer::push_rebuilt_params(const size_t& index) {
rebuilt_params_.push_back(params_[index]);
rebuilt_param_indices_.push_back(static_cast<int64_t>(index));
}
void Reducer::set_divide_factor() {
// If it was scheduled, wait on allreduce in forward pass that tells us
// division factor based on no. of currently participating processes.
if (div_factor_ == kUnsetDivFactor) {
div_factor_ = process_group_->getSize();
auto& workHandle = forwardPassWorkHandle_.workHandle;
if (workHandle && !forwardPassWorkHandle_.useStaticWorldSize) {
workHandle->wait();
// PyProcessGroup::PyWork doesn't expose value, so fetch it from the
// future
auto results = extractTensors(workHandle->getFuture()->value());
// Guard against the results being empty
TORCH_INTERNAL_ASSERT(!results.empty());
at::Tensor& res = results.front();
div_factor_ = res.item().to<int>();
}
}
}
// This is called before training and converts the gradients to the dtype they
// should be reduced in.
void Reducer::set_mixed_precision_param_dtype(c10::ScalarType dtype) {
mixed_precision_param_dtype_ = dtype;
for (auto& bucket : buckets_) {
bucket.gradients = bucket.gradients.to(dtype);
}
}
// Right now delay_all_reduce is only called when static_graph_=true and
// num_iterations_==1.
void Reducer::delay_all_reduce() {
std::lock_guard<std::mutex> lock(this->mutex_);
if (should_collect_runtime_stats()) {
record_backward_compute_end_time();
record_backward_comm_start_time();
}
// launch all reduce local used map
all_reduce_local_used_map();
// prepare to set unused_parameters_, if it is static graph,
// unused_parameters_ will not change after 1st iteration.
unused_parameters_.clear();
require_finalize_ = true;
// copy all gradients to buckets
for (const auto variable_index : c10::irange(params_.size())) {
// set unused_parameters_
if (numGradHooksTriggeredMap_[variable_index] == 0) {
unused_parameters_.push_back(variable_index);
}
set_divide_factor();
if (expect_sparse_gradients_[variable_index]) {
mark_variable_ready_sparse(variable_index);
} else {
mark_variable_ready_dense(variable_index);
}
}
// To avoid confusion around why static graph is picking up
// some parameters as unused on a rank vs not, we log
// unused parameter names for each rank for better
// debuggability when TORCH_DISTRIBUTED_DEBUG is set to
// INFO or DETAIL
if (ddp_debug_level_ != c10d::DebugLevel::Off) {
// construct one string to output
std::ostringstream unused_params_stream;
for (const auto& unused_index : unused_parameters_) {
auto param_name = param_names_.find(unused_index);
TORCH_INTERNAL_ASSERT(
param_name != param_names_.end(),
"Expected to find parameter name from unused parameters map in debug mode.");
// Add the param_name
unused_params_stream << "{" << param_name->second << "," << unused_index
<< "}";
}
// Each rank prints out all the unused parameters detected
if (!unused_parameters_.empty()) {
LOG(INFO) << "[Rank " << process_group_->getRank() << "]: "
<< "Parameter(s) (in the format of {param_name, index}): "
<< unused_params_stream.str()
<< " is(are) unused during first iteration. Since"
<< " static_graph=True is enabled for DDP, we expect"
<< " this set of unused parameters to remain consistent"
<< " on this rank throughout the training.";
}
}
// launch all reduces for all buckets
for (auto& bucket : buckets_) {
all_reduce_bucket(bucket);
}
finalize_backward();
}
void Reducer::set_logger(std::weak_ptr<c10d::Logger> logger) {
logger_ = std::move(logger);
}
// The function `autograd_hook` is called after the gradient for a
// model parameter has been accumulated into its gradient tensor.
// This function is only to be called from the autograd thread.
void Reducer::autograd_hook(size_t index) {
std::lock_guard<std::mutex> lock(this->mutex_);
if (!first_autograd_hook_called_) {
first_autograd_hook_called_ = true;
num_bwd_calls_++;
}
// See Note [Skip allreducing local_used_map_dev]
if (dynamic_graph_find_unused() || static_graph_first_iteration()) {
// Since it gets here, this param has been used for this iteration. We want
// to mark it in local_used_map_. During no_sync session, the same var can
// be set multiple times, which is OK as it does not affect correctness. As
// long as it is used once during no_sync session, it is marked as used.
// Only set it as locally used if the grad is defined. Otherwise, hooks can
// be fired with undefined grads, such as when not all outputs are used in
// DDP when computing loss. In this case, we don't want to mark it as
// locally used to ensure we don't touch the parameter's .grad field.
auto& variable = get_param_from_index(index);
runGradCallbackForVariable(variable, [&](auto& grad) {
if (grad.defined()) {
local_used_map_[static_cast<int64_t>(index)] = 1;
}
// The gradient is never modified.
return false;
});
}
if (static_graph_first_iteration()) {
numGradHooksTriggeredMap_[index] += 1;
return;
}
// Ignore if we don't expect to be called.
// This may be the case if the user wants to accumulate gradients
// for number of iterations before reducing them.
if (!expect_autograd_hooks_) {
return;
}
grad_ready_order_indices_.push_back(static_cast<int64_t>(index));
// If `find_unused_parameters_` is true there may be model parameters that
// went unused when computing the model output, they won't be part of the
// autograd graph, and won't receive gradients. These parameters are
// discovered in the `prepare_for_backward` function and their indexes stored
// in the `unused_parameters_` vector.
if (!has_marked_unused_parameters_) {
has_marked_unused_parameters_ = true;
for (const auto& unused_index : unused_parameters_) {
mark_variable_ready(unused_index);
}
}
// Rebuild bucket only if 1) it is the first time to rebuild bucket 2)
// static_graph_ is true or find_unused_parameters_ is false,
// 3) this backward pass needs to run allreduce.
// Here, we just dump tensors and their parameter indices into
// rebuilt_params_ and rebuilt_param_indices_ based on gradient arriving
// order, and then at the end of finalize_backward(), buckets will be
// rebuilt based on rebuilt_params_ and rebuilt_param_indices_, and then
// will be broadcasted and initialized.
// If it is static graph, after 1st iteration, check if a variable
// is ready for communication based on numGradHooksTriggeredMap_.
if (static_graph_after_first_iteration()) {
REDUCER_CHECK(
numGradHooksTriggeredMapPerIteration_[index] > 0,
logger_,
"Your training graph has changed in this iteration, ",
"e.g., one parameter is unused in first iteration, but ",
"then got used in the second iteration. this is not ",
"compatible with static_graph set to True.");
if (--numGradHooksTriggeredMapPerIteration_[index] == 0) {
if (should_rebuild_buckets()) {
push_rebuilt_params(index);
}
// Finally mark variable for which this function was originally called.
mark_variable_ready(index);
}
} else {
if (should_rebuild_buckets()) {
push_rebuilt_params(index);
}
// Finally mark variable for which this function was originally called.
mark_variable_ready(index);
}
}
void Reducer::all_reduce_local_used_map() {
// See Note [Skip allreducing local_used_map_dev]
// H2D from local_used_map_ to local_used_map_dev_
if (local_used_map_dev_.is_cuda() || local_used_map_dev_.is_privateuseone()) {
// Note [local_used_map_ -> local_used_map_dev copying]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// We do async H2D to avoid the blocking overhead. The async copy and
// allreduce respect the current stream, so will be sequenced
// correctly.
//
// Correct sequencing with respect to host operations is also
// essential. The H2D copy_ is stream ordered, while the host's
// changes to local_used_map_ are host ordered. If a large backlog of
// cuda/privateuseone-stream work pushes the copy_ far into the future, and
// if no blocking calls occur between now and finalize_backward()** such
// that finalize_backward() re-zeroes local_used_map_ on the host
// before the stream executes the copy_, copy_ will read those zeros
// instead of the values we thought we told it to read here. Copying
// local_used_map_ to a pinned temporary (which the pinned caching
// allocator should supply asynchronously) avoids this nasty, rare
// race condition.
//
// ** In the hoped-for case where all params are used, DDP itself
// won't do any blocking work between now and the re-zeroing, so the
// danger is real.
//
// Defensively ensures local_used_map_tmp is distinct from
// local_used_map_
auto local_used_map_tmp = at::native::empty_like(
local_used_map_,
c10::optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
local_used_map_.options().layout_opt(),
local_used_map_.options().device_opt(),
true /* pinned_memory */);
// Paranoid asserts here because in some workloads, the pinned
// allocator behaves in a way we don't understand, and may be bugged.
// See https://github.com/pytorch/pytorch/pull/54474
TORCH_INTERNAL_ASSERT(local_used_map_tmp.is_pinned());
TORCH_INTERNAL_ASSERT(
local_used_map_tmp.data_ptr() != local_used_map_.data_ptr());
local_used_map_tmp.copy_(local_used_map_);
local_used_map_dev_.copy_(local_used_map_tmp, true);
} else if (local_used_map_dev_.is_mtia()) {
// MTIA will probably have special logic in the future, and the following
// code might change drastically. Therefore, a new if case is created for
// MTIA; for now, the implementation is similar to the CUDA/privateuseone
// one, except for the pin-memory step.
auto local_used_map_tmp = at::native::empty_like(
local_used_map_,
c10::optTypeMetaToScalarType(local_used_map_.options().dtype_opt()),
local_used_map_.options().layout_opt(),
local_used_map_.options().device_opt());
local_used_map_tmp.copy_(local_used_map_);
local_used_map_dev_.copy_(local_used_map_tmp, true);
} else {
local_used_map_dev_.copy_(local_used_map_, true);
}
std::vector<at::Tensor> temp_local_used_map_dev_vec_ = {local_used_map_dev_};
local_used_work_ = process_group_->allreduce(temp_local_used_map_dev_vec_);
}
at::Tensor& Reducer::get_param_from_index(size_t index) {
const auto& bucket_index = variable_locators_[index];
auto& bucket = buckets_[bucket_index.bucket_index];
// Cannot simply access variable via `bucket.variables[variable_index]` since
// return value is used in `runGradCallbackForVariable()` which does not
// accept const tensors.
auto& variable = bucket.variables[bucket_index.intra_bucket_index];
return variable;
}
void Reducer::checkAndRaiseMarkedTwiceError(size_t index) {
// Something is wrong if all variables contained in this bucket have
// already been marked as ready.
// We don't expect the same variable to be marked ready twice.
bool marked_twice =
perIterationReadyParams_.find(index) != perIterationReadyParams_.end();
if (marked_twice) {
// Report index of param that has been marked twice. In debug mode, also
// report fully qualified parameter name.
auto param_name = param_names_.find(index);
const bool found_param_name = param_name != param_names_.end();
TORCH_INTERNAL_ASSERT(
ddp_debug_level_ == c10d::DebugLevel::Off || found_param_name,
"Expected to find parameter name in debug mode.");
std::string paramInfo = c10::str(
"Parameter at index ",
index,
found_param_name ? c10::str(" with name ", param_name->second) : "",
" has been marked as ready twice. This means that multiple autograd engine ",
" hooks have fired for this particular parameter during this iteration.");
// param_names_ is empty when debug mode is not enabled.
if (!found_param_name) {
paramInfo += c10::str(
" You can set the environment variable TORCH_DISTRIBUTED_DEBUG to either",
" INFO or DETAIL to print parameter names for further debugging.");
}
std::string common_error = c10::str(
"Expected to mark a variable ready only once. ",
"",
"This error is caused by one of the following reasons: ",
"1) Use of a module parameter outside the `forward` function. ",
"Please make sure model parameters are not shared across multiple ",
"concurrent forward-backward passes. or try to use _set_static_graph() ",
"as a workaround if this module graph does not change ",
"during training loop.",
"2) Reused parameters in multiple reentrant backward passes. For ",
"example, if you use multiple `checkpoint` functions to wrap the ",
"same part of your model, it would result in the same set of ",
"parameters been used by different reentrant backward passes ",
"multiple times, and hence marking a variable ready multiple times. ",
"DDP does not support such use cases in default. You can try to ",
"use _set_static_graph() as a workaround if your module graph ",
"does not change over iterations.");
common_error += c10::str("\n", paramInfo);
REDUCER_CHECK(
has_marked_unused_parameters_,
logger_,
common_error,
"3) Incorrect unused parameter detection. The return value of the ",
"`forward` function is inspected by the distributed data parallel ",
"wrapper to figure out if any of the module's parameters went ",
"unused. For unused parameters, DDP would not expect gradients from ",
"then. However, if an unused parameter becomes part of the autograd ",
"graph at a later point in time (e.g., in a reentrant backward when ",
"using `checkpoint`), the gradient will show up unexpectedly. If all ",
"parameters in the model participate in the backward pass, you can ",
"disable unused parameter detection by passing the keyword argument ",
"`find_unused_parameters=False` to ",
"`torch.nn.parallel.DistributedDataParallel`. If unused parameters ",
"in the model do not change over iterations, You can try to use ",
"_set_static_graph() as a workaround if this module graph does not ",
"change during training loop.");
REDUCER_CHECK(!has_marked_unused_parameters_, logger_, common_error);
}
}
void Reducer::mark_variable_ready(size_t variable_index) {
REDUCER_CHECK(
variable_index < variable_locators_.size(),
logger_,
"Out of range variable index.");
checkAndRaiseMarkedTwiceError(variable_index);
perIterationReadyParams_.insert(variable_index);
backward_stats_[variable_index] =
current_time_in_nanos() - backward_compute_start_time_;
// Any time we mark a variable ready (be it in line due to unused parameters,
// or via an autograd hook), we require a call to the finalize function. If
// this doesn't happen before the next iteration (or call to
// `prepare_for_backwards`), we know something is wrong.
require_finalize_ = true;
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
set_divide_factor();
if (bucket.expect_sparse_gradient) {
mark_variable_ready_sparse(variable_index);
} else {
mark_variable_ready_dense(variable_index);
}
// TODO(@pietern): Make this work for both CPU/CUDA tensors.
// When using CPU tensors we don't need to do this.
// Record event so that we can wait for all of them.
// auto& event = bucket.events[bucket_index.intra_bucket_index];
// event.record();
// Check if this was the final gradient for this bucket.
if (--bucket.pending == 0) {
mark_bucket_ready(bucket_index.bucket_index);
}
// Run finalizer function and kick off reduction for local_used_map once the
// final bucket was marked ready.
if (next_bucket_ == buckets_.size()) {
if (dynamic_graph_find_unused()) {
all_reduce_local_used_map();
}
torch::autograd::Engine::get_default_engine().queue_callback([this] {
std::lock_guard<std::mutex> lock(this->mutex_);
if (should_collect_runtime_stats()) {
record_backward_compute_end_time();
}
// Check that all buckets were completed and had their work kicked off.
TORCH_INTERNAL_ASSERT(next_bucket_ == buckets_.size());
if (static_graph_after_first_iteration() && should_rebuild_buckets()) {
for (const auto& unused_index : unused_parameters_) {
push_rebuilt_params(unused_index);
}
}
this->finalize_backward();
});
}
}
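// Dispatches a ready bucket either to the user-registered communication hook
// or, when none is registered, to the built-in allreduce-by-sum hook.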
c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_comm_hook(
GradBucket& grad_bucket) {
if (comm_hook_ == nullptr) {
return run_allreduce_hook(grad_bucket);
} else {
return comm_hook_->runHook(grad_bucket);
}
}
c10::intrusive_ptr<c10::ivalue::Future> Reducer::run_allreduce_hook(
GradBucket& grad_bucket) {
_AllReduceBySumCommHook allreduce_hook(process_group_);
return allreduce_hook.runHook(grad_bucket);
}
void Reducer::all_reduce_bucket(Bucket& bucket) {
auto variables_for_bucket = get_variables_for_bucket(next_bucket_, bucket);
// TODO(@pietern): Ensure proper synchronization with the CUDA events
// that recorded copies into this `gradients` tensor. If these copies are
// executed on non-default streams, the current stream for the device
// that holds the `gradients` tensor must wait on these events.
//
// As long as autograd uses the default stream for every device,
// these operations are implicitly sequenced, and we don't need to
// do any extra synchronization here.
const auto& tensor = bucket.gradients;
// TODO(@egienvalue): remove special case after view ops are fully
// supported on MTIA.
// If the bucket.gradients is on MTIA, bucket.bucket_views_in might not
// point to the same storage as bucket.gradients due to the special
// memory layout. It has to explicitly copy the data back to 1-D gradients.
if (tensor.is_mtia()) {
for (const auto i : c10::irange(bucket.variables.size())) {
const auto offset = bucket.offsets[i];
const auto length = bucket.lengths[i];
if (!bucket.bucket_views_in[i].is_alias_of(tensor)) {
tensor
.narrow(
0, static_cast<int64_t>(offset), static_cast<int64_t>(length))
.copy_(bucket.bucket_views_in[i].flatten());
}
}
}
GradBucket grad_bucket(
next_bucket_,
buckets_.size(),
tensor,
bucket.offsets,
bucket.lengths,
bucket.sizes_vec,
variables_for_bucket,
bucket.sparse_tensor_indices);
bucket.future_work = run_comm_hook(grad_bucket);
}
std::vector<at::Tensor> Reducer::get_variables_for_bucket(
size_t bucket_index,
const Bucket& bucket) const {
// Check if we have cached mapping previously.
if (has_rebuilt_bucket_ &&
cached_variables_for_bucket_.find(bucket_index) !=
cached_variables_for_bucket_.end()) {