@@ -92,13 +92,14 @@ ncclDataType_t ToNcclType(DataType t) {
 struct NcclManager::Participant {
   Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
               perftools::gputools::Stream* tensor_stream,
-              perftools::gputools::StreamExecutor* executor,
+              perftools::gputools::StreamExecutor* executor, int gpu_device_id,
               NcclManager::DoneCallback done_callback)
       : in_t(in_t),
         out_t(out_t),
         event_mgr(event_mgr),
         tensor_stream(tensor_stream),
         executor(executor),
+        gpu_device_id(gpu_device_id),
         done_callback(std::move(done_callback)) {
     DCHECK(executor != nullptr);
     DCHECK(event_mgr != nullptr);
@@ -120,7 +121,9 @@ struct NcclManager::Participant {
 
   // Matches the executor in CommunicatorMember::stream. Expected to be live for
   // process lifetime.
-  perftools::gputools::StreamExecutor* executor = nullptr;
+  perftools::gputools::StreamExecutor* const executor = nullptr;
+
+  const int gpu_device_id;
 
   NcclManager::DoneCallback done_callback;
 
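As an aside, the executor member also becomes a `const` pointer here. A minimal illustration of that distinction (hypothetical names, not taken from this patch): `T* const` fixes the pointer itself after the initializer list runs while leaving the pointee mutable, matching how `gpu_device_id` is stored as a `const int`.

```cpp
// Illustration of "T* const" vs "const T*" (hypothetical example, not from
// this patch): the pointer can never be reseated after construction, but the
// object it points at stays modifiable.
struct Member {
  int* const data;   // fixed pointer, mutable pointee
  const int device;  // fixed value for the member's lifetime
  Member(int* data, int device) : data(data), device(device) {}
};
```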
@@ -222,6 +225,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
   // Note that this is done under the lock; performance is not expected to
   // matter as this happens a very small number of times.
   std::vector<CommunicatorMember> members(num_devices);
+  std::vector<int> devices(num_devices);
   for (int i = 0; i < num_devices; ++i) {
     auto* executor = collective->participants[i]->executor;
 
@@ -249,30 +253,14 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
     }
 
     members[i].nccl_stream = nccl_stream;
+    devices[i] = collective->participants[i]->gpu_device_id;
   }
 
-  // Call ncclCommInitRank for each member.
-  ncclUniqueId id;
-  CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id));
-  std::unique_ptr<thread::ThreadPool> pool(
-      new thread::ThreadPool(env, "ncclCommInitRank", num_devices));
-  std::vector<ncclResult_t> results(num_devices);
+  std::vector<ncclComm_t> nccl_comms(num_devices);
+  auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
+  CHECK_EQ(result, ncclSuccess);
   for (int rank = 0; rank < num_devices; ++rank) {
-    CommunicatorMember* member = &members[rank];
-    ncclResult_t* result = &results[rank];
-    pool->Schedule([member, num_devices, result, rank, &id]() {
-      ScopedActivateExecutorContext scoped_context(
-          member->nccl_stream->executor);
-      LOG(INFO) << "Calling ncclCommInitRank for rank " << rank;
-      *result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank);
-      LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << ": "
-                << *result;
-    });
-  }
-
-  pool.reset();  // wait for completion.
-  for (int i = 0; i < num_devices; ++i) {
-    CHECK_EQ(results[i], ncclSuccess);
+    members[rank].nccl_comm = nccl_comms[rank];
   }
   communicators_.emplace_back(new Communicator(std::move(members)));
   return communicators_.back().get();
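For reference, a minimal standalone sketch of the single-threaded ncclCommInitAll pattern this hunk switches to, replacing the per-rank ncclCommInitRank threads. Everything below is illustrative and independent of the TensorFlow code; it assumes the communicators should cover CUDA devices 0..n-1.

```cpp
// Minimal sketch: initialize one NCCL communicator per visible CUDA device
// with a single ncclCommInitAll call from one thread, instead of spawning a
// thread per rank for ncclCommInitRank.
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

int main() {
  int num_devices = 0;
  cudaGetDeviceCount(&num_devices);

  std::vector<int> devices(num_devices);
  for (int i = 0; i < num_devices; ++i) devices[i] = i;  // CUDA device ids

  // One ncclComm_t per device, all created in a single blocking call.
  std::vector<ncclComm_t> comms(num_devices);
  ncclResult_t result =
      ncclCommInitAll(comms.data(), num_devices, devices.data());
  if (result != ncclSuccess) {
    std::fprintf(stderr, "ncclCommInitAll failed: %s\n",
                 ncclGetErrorString(result));
    return 1;
  }

  for (int i = 0; i < num_devices; ++i) ncclCommDestroy(comms[i]);
  return 0;
}
```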
@@ -281,37 +269,38 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
 void NcclManager::AddToAllReduce(int num_devices, const string& key,
                                  ncclRedOp_t reduction_op,
                                  perftools::gputools::StreamExecutor* executor,
-                                 EventMgr* event_mgr,
+                                 int gpu_device_id, EventMgr* event_mgr,
                                  perftools::gputools::Stream* tensor_stream,
                                  const Tensor* in_t, Tensor* out_t,
                                  const DoneCallback& done_callback) {
-  std::unique_ptr<Participant> participant(new Participant(
-      in_t, out_t, event_mgr, tensor_stream, executor, done_callback));
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, done_callback));
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kAllReduce, reduction_op);
 }
 
 void NcclManager::AddBroadcastSend(
     int num_devices, const string& key,
-    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
-    perftools::gputools::Stream* tensor_stream, const Tensor* in_t,
-    DoneCallback done_callback) {
+    perftools::gputools::StreamExecutor* executor, int gpu_device_id,
+    EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
+    const Tensor* in_t, DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
       new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
-                      executor, done_callback));
+                      executor, gpu_device_id, done_callback));
   participant->root = true;
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kBroadcast, ncclSum /* unused */);
 }
 
 void NcclManager::AddBroadcastRecv(
     int num_devices, const string& key,
-    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
-    perftools::gputools::Stream* tensor_stream, Tensor* out_t,
-    DoneCallback done_callback) {
+    perftools::gputools::StreamExecutor* executor, int gpu_device_id,
+    EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
+    Tensor* out_t, DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
       new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
-                      executor, done_callback));
+                      executor, gpu_device_id, done_callback));
   AddParticipant(num_devices, key, std::move(participant), out_t->dtype(),
                  kBroadcast, ncclSum /* unused */);
 }
@@ -331,7 +320,7 @@ void NcclManager::AddParticipant(int num_devices, const string& key,
   }
   Collective* collective = collective_ptr.get();
   DCHECK_EQ(collective->type, collective_type);
-  DCHECK_EQ(collective->participants.size(), num_devices);
+  DCHECK_LT(collective->participants.size(), num_devices);
   collective->participants.emplace_back(std::move(participant));
   ++collective->available_participants;
 
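The relaxed check reflects that the DCHECK runs before the new participant is appended, so the vector can hold at most num_devices - 1 entries at that point. A self-contained sketch of that invariant (illustrative only, not TensorFlow code):

```cpp
// Illustration of why the pre-insertion check must be "<" rather than "==":
// the size is compared before push_back, so it is strictly below num_devices.
#include <cassert>
#include <vector>

int main() {
  const int num_devices = 4;
  std::vector<int> participants;
  for (int rank = 0; rank < num_devices; ++rank) {
    // Mirrors DCHECK_LT(collective->participants.size(), num_devices).
    assert(static_cast<int>(participants.size()) < num_devices);
    participants.push_back(rank);
  }
  // Only after the final insertion does the size reach num_devices.
  assert(static_cast<int>(participants.size()) == num_devices);
  return 0;
}
```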