
Commit f3405c2

Change nccl_manager to use ncclCommInitAll.
Change: 148169806
1 parent 4c26c32 commit f3405c2
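
For readers unfamiliar with the NCCL call named in the title, here is a minimal standalone sketch (not code from this commit) of single-process initialization via ncclCommInitAll. It assumes the NCCL and CUDA runtime headers are available and that every visible GPU joins the clique:

// Standalone sketch: one NCCL communicator per visible CUDA device,
// created with a single ncclCommInitAll call.
#include <cstdio>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

int main() {
  int num_devices = 0;
  if (cudaGetDeviceCount(&num_devices) != cudaSuccess || num_devices == 0) {
    std::fprintf(stderr, "no CUDA devices visible\n");
    return 1;
  }

  // ncclCommInitAll takes the CUDA device ordinals directly; no ncclUniqueId
  // exchange or per-rank threads are needed in the single-process case.
  std::vector<int> devices(num_devices);
  for (int i = 0; i < num_devices; ++i) devices[i] = i;

  std::vector<ncclComm_t> comms(num_devices);
  ncclResult_t result =
      ncclCommInitAll(comms.data(), num_devices, devices.data());
  if (result != ncclSuccess) {
    std::fprintf(stderr, "ncclCommInitAll failed: %s\n",
                 ncclGetErrorString(result));
    return 1;
  }

  // ... use the communicators for collectives ...

  for (ncclComm_t comm : comms) ncclCommDestroy(comm);
  return 0;
}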

6 files changed: 46 additions & 52 deletions

tensorflow/contrib/nccl/kernels/nccl_manager.cc

Lines changed: 24 additions & 35 deletions
@@ -92,13 +92,14 @@ ncclDataType_t ToNcclType(DataType t) {
 struct NcclManager::Participant {
   Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
               perftools::gputools::Stream* tensor_stream,
-              perftools::gputools::StreamExecutor* executor,
+              perftools::gputools::StreamExecutor* executor, int gpu_device_id,
               NcclManager::DoneCallback done_callback)
       : in_t(in_t),
         out_t(out_t),
         event_mgr(event_mgr),
         tensor_stream(tensor_stream),
         executor(executor),
+        gpu_device_id(gpu_device_id),
         done_callback(std::move(done_callback)) {
     DCHECK(executor != nullptr);
     DCHECK(event_mgr != nullptr);
@@ -120,7 +121,9 @@ struct NcclManager::Participant {
 
   // Matches the executor in CommunicatorMember::stream. Expected to be live for
   // process lifetime.
-  perftools::gputools::StreamExecutor* executor = nullptr;
+  perftools::gputools::StreamExecutor* const executor = nullptr;
+
+  const int gpu_device_id;
 
   NcclManager::DoneCallback done_callback;
 
@@ -222,6 +225,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
   // Note that this is done under the lock; performance is not expected to
   // matter as this happens a very small number of times.
   std::vector<CommunicatorMember> members(num_devices);
+  std::vector<int> devices(num_devices);
   for (int i = 0; i < num_devices; ++i) {
     auto* executor = collective->participants[i]->executor;
 
@@ -249,30 +253,14 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
     }
 
     members[i].nccl_stream = nccl_stream;
+    devices[i] = collective->participants[i]->gpu_device_id;
   }
 
-  // Call ncclCommInitRank for each member.
-  ncclUniqueId id;
-  CHECK_EQ(ncclSuccess, ncclGetUniqueId(&id));
-  std::unique_ptr<thread::ThreadPool> pool(
-      new thread::ThreadPool(env, "ncclCommInitRank", num_devices));
-  std::vector<ncclResult_t> results(num_devices);
+  std::vector<ncclComm_t> nccl_comms(num_devices);
+  auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
+  CHECK_EQ(result, ncclSuccess);
   for (int rank = 0; rank < num_devices; ++rank) {
-    CommunicatorMember* member = &members[rank];
-    ncclResult_t* result = &results[rank];
-    pool->Schedule([member, num_devices, result, rank, &id]() {
-      ScopedActivateExecutorContext scoped_context(
-          member->nccl_stream->executor);
-      LOG(INFO) << "Calling ncclCommInitRank for rank " << rank;
-      *result = ncclCommInitRank(&member->nccl_comm, num_devices, id, rank);
-      LOG(INFO) << "Done calling ncclCommInitRank for rank " << rank << " : "
-                << *result;
-    });
-  }
-
-  pool.reset();  // wait for completion.
-  for (int i = 0; i < num_devices; ++i) {
-    CHECK_EQ(results[i], ncclSuccess);
+    members[rank].nccl_comm = nccl_comms[rank];
   }
   communicators_.emplace_back(new Communicator(std::move(members)));
   return communicators_.back().get();
@@ -281,37 +269,38 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
 void NcclManager::AddToAllReduce(int num_devices, const string& key,
                                  ncclRedOp_t reduction_op,
                                  perftools::gputools::StreamExecutor* executor,
-                                 EventMgr* event_mgr,
+                                 int gpu_device_id, EventMgr* event_mgr,
                                  perftools::gputools::Stream* tensor_stream,
                                  const Tensor* in_t, Tensor* out_t,
                                  const DoneCallback& done_callback) {
-  std::unique_ptr<Participant> participant(new Participant(
-      in_t, out_t, event_mgr, tensor_stream, executor, done_callback));
+  std::unique_ptr<Participant> participant(
+      new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
+                      gpu_device_id, done_callback));
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kAllReduce, reduction_op);
 }
 
 void NcclManager::AddBroadcastSend(
     int num_devices, const string& key,
-    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
-    perftools::gputools::Stream* tensor_stream, const Tensor* in_t,
-    DoneCallback done_callback) {
+    perftools::gputools::StreamExecutor* executor, int gpu_device_id,
+    EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
+    const Tensor* in_t, DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
       new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
-                      executor, done_callback));
+                      executor, gpu_device_id, done_callback));
   participant->root = true;
   AddParticipant(num_devices, key, std::move(participant), in_t->dtype(),
                  kBroadcast, ncclSum /* unused */);
 }
 
 void NcclManager::AddBroadcastRecv(
     int num_devices, const string& key,
-    perftools::gputools::StreamExecutor* executor, EventMgr* event_mgr,
-    perftools::gputools::Stream* tensor_stream, Tensor* out_t,
-    DoneCallback done_callback) {
+    perftools::gputools::StreamExecutor* executor, int gpu_device_id,
+    EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
+    Tensor* out_t, DoneCallback done_callback) {
   std::unique_ptr<Participant> participant(
       new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
-                      executor, done_callback));
+                      executor, gpu_device_id, done_callback));
   AddParticipant(num_devices, key, std::move(participant), out_t->dtype(),
                  kBroadcast, ncclSum /* unused */);
 }
@@ -331,7 +320,7 @@ void NcclManager::AddParticipant(int num_devices, const string& key,
   }
   Collective* collective = collective_ptr.get();
   DCHECK_EQ(collective->type, collective_type);
-  DCHECK_EQ(collective->participants.size(), num_devices);
+  DCHECK_LT(collective->participants.size(), num_devices);
   collective->participants.emplace_back(std::move(participant));
   ++collective->available_participants;
 
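
For comparison with the removed code above, a standalone sketch (again not TensorFlow code) of the per-rank pattern it followed: ncclCommInitRank blocks until every rank of the clique has called in, so a single process drives it from one thread per device, making that device current first. cudaSetDevice stands in here for ScopedActivateExecutorContext:

// Standalone sketch of the removed pattern: one ncclCommInitRank call per
// rank, sharing a single ncclUniqueId, each call from its own thread.
#include <thread>
#include <vector>

#include <cuda_runtime.h>
#include <nccl.h>

std::vector<ncclComm_t> InitPerRank(int num_devices) {
  ncclUniqueId id;
  ncclGetUniqueId(&id);  // one id shared by every rank

  std::vector<ncclComm_t> comms(num_devices);
  std::vector<std::thread> workers;
  for (int rank = 0; rank < num_devices; ++rank) {
    workers.emplace_back([&, rank] {
      cudaSetDevice(rank);  // make this rank's device current
      // Blocks until all num_devices ranks have called in.
      ncclCommInitRank(&comms[rank], num_devices, id, rank);
    });
  }
  for (std::thread& t : workers) t.join();
  return comms;
}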

tensorflow/contrib/nccl/kernels/nccl_manager.h

Lines changed: 3 additions & 3 deletions
@@ -57,7 +57,7 @@ class NcclManager {
   void AddToAllReduce(int num_devices, const string& key,
                       ncclRedOp_t reduction_op,
                       perftools::gputools::StreamExecutor* executor,
-                      EventMgr* event_mgr,
+                      int gpu_device_id, EventMgr* event_mgr,
                       perftools::gputools::Stream* tensor_stream,
                       const Tensor* in_t, Tensor* out_t,
                       const DoneCallback& done_callback);
@@ -66,12 +66,12 @@ class NcclManager {
   // to all receivers.
   void AddBroadcastSend(int num_devices, const string& key,
                         perftools::gputools::StreamExecutor* executor,
-                        EventMgr* event_mgr,
+                        int gpu_device_id, EventMgr* event_mgr,
                         perftools::gputools::Stream* tensor_stream,
                         const Tensor* in_t, DoneCallback done_callback);
   void AddBroadcastRecv(int num_devices, const string& key,
                         perftools::gputools::StreamExecutor* executor,
-                        EventMgr* event_mgr,
+                        int gpu_device_id, EventMgr* event_mgr,
                         perftools::gputools::Stream* tensor_stream,
                         Tensor* out_t, DoneCallback done_callback);
 

tensorflow/contrib/nccl/kernels/nccl_manager_test.cc

Lines changed: 6 additions & 5 deletions
@@ -193,9 +193,9 @@ TEST_F(NcclManagerTest, BasicSumReduction) {
       auto* event_mgr = device->tensorflow_gpu_device_info()->event_mgr;
       auto* stream = device->tensorflow_gpu_device_info()->stream;
       NcclManager::instance()->AddToAllReduce(
-          num_ranks, "allreduce", reduction_op, device->executor(), event_mgr,
-          stream, &test_case->ins[device_num], &test_case->outs[device_num],
-          CreateDoneCallback(test_case.get()));
+          num_ranks, "allreduce", reduction_op, device->executor(),
+          device->gpu_id(), event_mgr, stream, &test_case->ins[device_num],
+          &test_case->outs[device_num], CreateDoneCallback(test_case.get()));
     }
 
     LOG(ERROR) << "Verifying results";
@@ -259,8 +259,9 @@ TEST_F(NcclManagerTest, MultipleCallers) {
      TestCase* test_case = test_cases[test_num].get();
      NcclManager::instance()->AddToAllReduce(
          num_ranks, strings::StrCat("allreduce", test_num), ncclSum,
-         device->executor(), event_mgr, stream, &test_case->ins[device_num],
-         &test_case->outs[device_num], CreateDoneCallback(test_case));
+         device->executor(), device->gpu_id(), event_mgr, stream,
+         &test_case->ins[device_num], &test_case->outs[device_num],
+         CreateDoneCallback(test_case));
    };
    pool->Schedule(fn);
  }

tensorflow/contrib/nccl/kernels/nccl_ops.cc

Lines changed: 11 additions & 9 deletions
@@ -90,11 +90,11 @@ class NcclAllReduceOpKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
-    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddToAllReduce(
         num_devices(), GetCollectiveKey(c), reduction_op_,
-        compute_stream->parent(), event_mgr, compute_stream, in_t, out_t,
-        actual_done);
+        compute_stream->parent(), gpu_info->gpu_id, gpu_info->event_mgr,
+        compute_stream, in_t, out_t, actual_done);
   }
 
  private:
@@ -115,10 +115,11 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
-    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddBroadcastSend(
-        num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
-        compute_stream, &c->input(0), std::move(actual_done));
+        num_devices(), GetCollectiveKey(c), compute_stream->parent(),
+        gpu_info->gpu_id, gpu_info->event_mgr, compute_stream, &c->input(0),
+        std::move(actual_done));
   }
 };
 REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU),
@@ -142,10 +143,11 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase {
     };
 
     auto* compute_stream = c->op_device_context()->stream();
-    EventMgr* event_mgr = c->device()->tensorflow_gpu_device_info()->event_mgr;
+    auto* gpu_info = c->device()->tensorflow_gpu_device_info();
     NcclManager::instance()->AddBroadcastRecv(
-        num_devices(), GetCollectiveKey(c), compute_stream->parent(), event_mgr,
-        compute_stream, out_t, std::move(actual_done));
+        num_devices(), GetCollectiveKey(c), compute_stream->parent(),
+        gpu_info->gpu_id, gpu_info->event_mgr, compute_stream, out_t,
+        std::move(actual_done));
   }
 };
 REGISTER_KERNEL_BUILDER(

tensorflow/core/common_runtime/gpu/gpu_device.cc

Lines changed: 1 addition & 0 deletions
@@ -257,6 +257,7 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
   gpu_device_info_->stream = streams_[0].compute;
   gpu_device_info_->default_context = device_contexts_[0];
   gpu_device_info_->event_mgr = em_.get();
+  gpu_device_info_->gpu_id = gpu_id_;
   set_tensorflow_gpu_device_info(gpu_device_info_);
 
   return Status::OK();

tensorflow/core/framework/device_base.h

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ class DeviceBase {
     perftools::gputools::Stream* stream = nullptr;
     DeviceContext* default_context = nullptr;
     EventMgr* event_mgr = nullptr;
+    int gpu_id = -1;
   };
 
   // Does not take ownership.
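
The new gpu_id field defaults to -1, so callers holding only a DeviceBase* can tell whether the device carries GPU info at all. A small hypothetical helper (not part of this commit) illustrating that use:

#include "tensorflow/core/framework/device_base.h"

// Hypothetical helper, not part of this commit: returns the CUDA device
// ordinal recorded by BaseGPUDevice::Init, or -1 for CPU/unknown devices.
int GetGpuDeviceId(tensorflow::DeviceBase* device) {
  auto* gpu_info = device->tensorflow_gpu_device_info();
  return gpu_info != nullptr ? gpu_info->gpu_id : -1;
}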
