Skip to content

Commit 36c2d9a

Browse files
rohan100jain authored and tensorflower-gardener committed
Getting rid of the threadpool from FunctionBufferingResource. It wasn't really serving much purpose apart from moving all the function logic execution onto those threads and making the destruction of the resource quite complicated.
PiperOrigin-RevId: 191017836
1 parent ac39aec commit 36c2d9a

File tree

14 files changed

+92
-57
lines changed

14 files changed

+92
-57
lines changed

tensorflow/contrib/data/kernels/prefetching_kernels.cc

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ class FunctionBufferingResource : public ResourceBase {
4040
const NameAttrList& func, int64 buffer_size,
4141
const string& source_device,
4242
const string& target_device,
43-
const std::vector<Tensor>& func_args,
44-
int64 thread_pool_size)
43+
const std::vector<Tensor>& func_args)
4544
: lib_(lib),
4645
pflr_(std::move(pflr)),
4746
func_(func),
@@ -52,22 +51,10 @@ class FunctionBufferingResource : public ResourceBase {
5251
handle_(kInvalidHandle),
5352
is_buffering_(false),
5453
end_of_sequence_(false),
55-
cancelled_(false) {
56-
if (thread_pool_size > 0) {
57-
thread_pool_ = new thread::ThreadPool(Env::Default(), ThreadOptions(),
58-
"buffer_resource", thread_pool_size,
59-
false /* low_latency_hint */);
60-
runner_ = [this](std::function<void()> c) {
61-
thread_pool_->Schedule(std::move(c));
62-
};
63-
}
64-
}
54+
cancelled_(false) {}
6555

6656
~FunctionBufferingResource() override {
6757
Cancel();
68-
if (thread_pool_ != nullptr) {
69-
delete thread_pool_;
70-
}
7158
}
7259

7360
string DebugString() override {
@@ -179,17 +166,12 @@ class FunctionBufferingResource : public ResourceBase {
179166
for (int i = 0; i < cancellation_callbacks.size(); ++i) {
180167
cancellation_callbacks[i](cancellation_buffer_elements[i]);
181168
}
182-
// We only wait on cond_var_ in the destructor, so there would atmost be
183-
// one waiter to notify.
184-
cond_var_.notify_one();
169+
cond_var_.notify_all();
185170
return;
186171
}
187172
FunctionLibraryRuntime::Options opts;
188173
// Copied from CapturedFunction::generate_step_id();
189174
opts.step_id = -std::abs(static_cast<int64>(random::New64()));
190-
if (runner_ != nullptr) {
191-
opts.runner = &runner_;
192-
}
193175
opts.source_device = source_device_;
194176
AllocatorAttributes arg_alloc_attr;
195177
arg_alloc_attr.set_on_host(true);
@@ -251,11 +233,9 @@ class FunctionBufferingResource : public ResourceBase {
251233
const string source_device_;
252234
const string target_device_;
253235
const std::vector<Tensor> func_args_;
254-
thread::ThreadPool* thread_pool_ = nullptr;
255236
FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
256237
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
257238
std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
258-
std::function<void(std::function<void()>)> runner_ = nullptr;
259239
bool is_buffering_ GUARDED_BY(mu_);
260240
bool end_of_sequence_ GUARDED_BY(mu_);
261241
bool cancelled_ GUARDED_BY(mu_);
@@ -270,7 +250,6 @@ class FunctionBufferResourceHandleOp : public OpKernel {
270250
OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
271251
OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
272252
OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
273-
OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_));
274253
}
275254

276255
~FunctionBufferResourceHandleOp() override {
@@ -318,7 +297,7 @@ class FunctionBufferResourceHandleOp : public OpKernel {
318297
this](FunctionBufferingResource** ptr) {
319298
*ptr = new FunctionBufferingResource(
320299
clone_lib, std::move(pflr), func_, buffer_size_,
321-
source_device, target_device, func_args, thread_pool_size_);
300+
source_device, target_device, func_args);
322301
return Status::OK();
323302
}));
324303
core::ScopedUnref s(buffer);
@@ -340,7 +319,6 @@ class FunctionBufferResourceHandleOp : public OpKernel {
340319
int64 buffer_size_;
341320
string container_;
342321
string name_;
343-
int64 thread_pool_size_;
344322
};
345323

346324
REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")

tensorflow/contrib/data/ops/dataset_ops.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ REGISTER_OP("FunctionBufferingResource")
5353
.Attr("container: string")
5454
.Attr("f: func")
5555
.Attr("buffer_size: int")
56-
.Attr("thread_pool_size: int")
5756
.SetShapeFn(shape_inference::UnknownShape)
5857
.Doc(R"doc(
5958
Creates a resource that fills up a buffer by making function calls.
@@ -63,7 +62,6 @@ target_device: Target device to execute the function on.
6362
resource: Handle to the resource created.
6463
f: Function to be executed.
6564
buffer_size: Size of the buffer.
66-
thread_pool_size: Size of the threadpool doing the prefetching.
6765
container: If non-empty, this resource is placed in the given container.
6866
Otherwise, a default container is used.
6967
shared_name: If non-empty, this resource will be shared under the given name

tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def _remote_fn(h):
7070
target_device=target,
7171
string_arg=ds_iterator_handle,
7272
buffer_size=3,
73-
thread_pool_size=2,
7473
shared_name=buffer_name)
7574

7675
with ops.device(device1):

tensorflow/contrib/data/python/ops/prefetching_ops.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def function_buffering_resource(string_arg,
3636
target_device,
3737
f,
3838
buffer_size,
39-
thread_pool_size=0,
4039
container="",
4140
shared_name=None,
4241
name=None):
@@ -48,7 +47,6 @@ def function_buffering_resource(string_arg,
4847
shared_name=shared_name,
4948
f=f,
5049
buffer_size=buffer_size,
51-
thread_pool_size=thread_pool_size,
5250
container=container,
5351
name=name)
5452

@@ -90,8 +88,7 @@ def _prefetch_fn(handle):
9088
target_device=gen_dataset_ops.iterator_get_device(
9189
input_iterator._iterator_resource),
9290
string_arg=input_iterator_handle,
93-
buffer_size=buffer_size,
94-
thread_pool_size=0)
91+
buffer_size=buffer_size)
9592

9693
def get_next(self, name=None):
9794
"""See @{tf.data.Iterator.get_next}."""

tensorflow/contrib/distribute/python/prefetching_ops_v2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def _prefetch_fn(handle):
5959
f=_prefetch_fn,
6060
target_device=target_device,
6161
string_arg=input_iterator_handle,
62-
buffer_size=buffer_size,
63-
thread_pool_size=0)
62+
buffer_size=buffer_size)
6463
self._buffering_resources.append(buffer_resource_handle)
6564

6665
def get_next(self, name=None):

tensorflow/contrib/eager/python/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ def remote_fn(h):
9898
f=remote_fn,
9999
target_device=target,
100100
buffer_size=10,
101-
thread_pool_size=1,
102101
container="",
103102
shared_name=_generate_shared_name("function_buffer_resource"))
104103
self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long

tensorflow/core/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3273,6 +3273,18 @@ tf_cc_test(
32733273
],
32743274
)
32753275

3276+
tf_cc_test(
3277+
name = "common_runtime_process_util_test",
3278+
size = "small",
3279+
srcs = ["common_runtime/process_util_test.cc"],
3280+
linkstatic = tf_kernel_tests_linkstatic(),
3281+
deps = [
3282+
":core_cpu_internal",
3283+
":test",
3284+
":test_main",
3285+
],
3286+
)
3287+
32763288
tf_cc_test(
32773289
name = "common_runtime_rendezvous_util_test",
32783290
size = "small",

tensorflow/core/common_runtime/direct_session.cc

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "tensorflow/core/common_runtime/graph_optimizer.h"
2828
#include "tensorflow/core/common_runtime/memory_types.h"
2929
#include "tensorflow/core/common_runtime/optimization_registry.h"
30+
#include "tensorflow/core/common_runtime/process_util.h"
3031
#include "tensorflow/core/common_runtime/step_stats_collector.h"
3132
#include "tensorflow/core/framework/function.h"
3233
#include "tensorflow/core/framework/graph.pb_text.h"
@@ -69,20 +70,6 @@ auto* direct_session_runs = monitoring::Counter<0>::New(
6970
"/tensorflow/core/direct_session_runs",
7071
"The number of times DirectSession::Run() has been called.");
7172

72-
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
73-
const int32 t = options.config.inter_op_parallelism_threads();
74-
if (t != 0) return t;
75-
// Default to using the number of cores available in the process.
76-
return port::NumSchedulableCPUs();
77-
}
78-
79-
thread::ThreadPool* NewThreadPoolFromSessionOptions(
80-
const SessionOptions& options) {
81-
const int32 num_threads = NumInterOpThreadsFromSessionOptions(options);
82-
VLOG(1) << "Direct session inter op parallelism threads: " << num_threads;
83-
return new thread::ThreadPool(options.env, "Compute", num_threads);
84-
}
85-
8673
Status NewThreadPoolFromThreadPoolOptions(
8774
const SessionOptions& options,
8875
const ThreadPoolOptionProto& thread_pool_options, int pool_number,

tensorflow/core/common_runtime/eager/context.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515

1616
#include "tensorflow/core/common_runtime/eager/context.h"
1717

18+
#include "tensorflow/core/common_runtime/process_util.h"
19+
1820
namespace tensorflow {
1921

2022
EagerContext::EagerContext(const SessionOptions& opts,
@@ -25,9 +27,10 @@ EagerContext::EagerContext(const SessionOptions& opts,
2527
device_manager_(std::move(device_mgr)),
2628
devices_(device_manager_->ListDevices()),
2729
rendezvous_(rendezvous),
28-
pflr_(new ProcessFunctionLibraryRuntime(device_manager_.get(), opts.env,
29-
TF_GRAPH_DEF_VERSION,
30-
&func_lib_def_, {})),
30+
thread_pool_(NewThreadPoolFromSessionOptions(opts)),
31+
pflr_(new ProcessFunctionLibraryRuntime(
32+
device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
33+
{}, thread_pool_.get())),
3134
log_device_placement_(opts.config.log_device_placement()),
3235
async_default_(async) {
3336
if (async_default_) {

tensorflow/core/common_runtime/eager/context.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License.
3030
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
3131
#include "tensorflow/core/framework/rendezvous.h"
3232
#include "tensorflow/core/lib/core/stringpiece.h"
33+
#include "tensorflow/core/lib/core/threadpool.h"
3334
#include "tensorflow/core/lib/gtl/inlined_vector.h"
3435
#include "tensorflow/core/lib/gtl/map_util.h"
3536
#include "tensorflow/core/lib/gtl/stl_util.h"
@@ -160,6 +161,8 @@ class EagerContext {
160161
FunctionLibraryDefinition func_lib_def_ GUARDED_BY(functions_mu_){
161162
OpRegistry::Global(), {}};
162163

164+
std::unique_ptr<thread::ThreadPool> thread_pool_;
165+
163166
// One FunctionLibraryRuntime per device.
164167
// func_libs[i] is the FunctionLibraryRuntime corresponding to
165168
// session->devices[i].

tensorflow/core/common_runtime/function.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -796,16 +796,17 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
796796
done(status);
797797
};
798798
}
799-
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
800-
parent_->Run(run_opts, handle, args, rets, done);
801-
return;
802-
}
803799

804800
if (run_opts.runner == nullptr) {
805801
run_opts.runner = &default_runner_;
806802
}
807803
DCHECK(run_opts.runner != nullptr);
808804

805+
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
806+
parent_->Run(run_opts, handle, args, rets, done);
807+
return;
808+
}
809+
809810
Executor::Args* exec_args = new Executor::Args;
810811
// Inherit the step_id from the caller.
811812
exec_args->step_id = run_opts.step_id;

tensorflow/core/common_runtime/process_util.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ thread::ThreadPool* ComputePool(const SessionOptions& options) {
4646
return compute_pool;
4747
}
4848

49+
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
50+
const int32 t = options.config.inter_op_parallelism_threads();
51+
if (t != 0) return t;
52+
// Default to using the number of cores available in the process.
53+
return port::NumSchedulableCPUs();
54+
}
55+
56+
thread::ThreadPool* NewThreadPoolFromSessionOptions(
57+
const SessionOptions& options) {
58+
const int32 num_threads = NumInterOpThreadsFromSessionOptions(options);
59+
VLOG(1) << "Direct session inter op parallelism threads: " << num_threads;
60+
return new thread::ThreadPool(options.env, "Compute", num_threads);
61+
}
62+
4963
void SchedClosure(std::function<void()> closure) {
5064
if (port::Tracing::IsActive()) {
5165
const uint64 id = port::Tracing::UniqueId();

tensorflow/core/common_runtime/process_util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ namespace tensorflow {
3030
// using 'options'. Caller does not take ownership over threadpool.
3131
thread::ThreadPool* ComputePool(const SessionOptions& options);
3232

33+
// Returns number of inter op threads.
34+
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options);
35+
36+
// Creates a thread pool with number of inter op threads.
37+
thread::ThreadPool* NewThreadPoolFromSessionOptions(
38+
const SessionOptions& options);
39+
3340
// Schedule "closure" in the default thread queue.
3441
void SchedClosure(std::function<void()> closure);
3542

tensorflow/core/common_runtime/process_util_test.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/core/common_runtime/process_util.h"
16+
17+
#include "tensorflow/core/platform/test.h"
18+
19+
namespace tensorflow {
20+
namespace {
21+
22+
TEST(ProcessUtilTest, NumThreads) {
23+
SessionOptions opts;
24+
opts.config.set_inter_op_parallelism_threads(10);
25+
EXPECT_EQ(10, NumInterOpThreadsFromSessionOptions(opts));
26+
}
27+
28+
TEST(ProcessUtilTest, ThreadPool) {
29+
SessionOptions opts;
30+
opts.config.set_inter_op_parallelism_threads(10);
31+
32+
thread::ThreadPool* pool = NewThreadPoolFromSessionOptions(opts);
33+
EXPECT_EQ(10, pool->NumThreads());
34+
delete pool;
35+
}
36+
37+
} // anonymous namespace
38+
} // namespace tensorflow

0 commit comments

Comments
 (0)