Skip to content

Commit dc19800

Browse files
committed
Merge pull request tensorflow#2023 from caisq/r0.8-tensorforest-2
R0.8 tensorforest cherry-pick
2 parents ac3c683 + f7ec1ed commit dc19800

34 files changed

+1162
-444
lines changed

tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc

Lines changed: 109 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818
// only op that involves tree traversal, and is constructed so that it can
1919
// be run in parallel on separate batches of data.
2020
#include <unordered_map>
21+
#include <vector>
2122

2223
#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"
2324

2425
#include "tensorflow/core/framework/op.h"
2526
#include "tensorflow/core/framework/op_kernel.h"
2627

2728
#include "tensorflow/core/lib/gtl/map_util.h"
29+
#include "tensorflow/core/util/work_sharder.h"
2830

2931
namespace tensorflow {
3032

3133
using std::get;
34+
using std::make_pair;
3235
using std::make_tuple;
3336
using std::pair;
3437
using std::tuple;
@@ -42,6 +45,71 @@ using tensorforest::DecideNode;
4245
using tensorforest::Initialize;
4346
using tensorforest::IsAllInitialized;
4447

48+
// A data structure to store the results of parallel tree traversal.
49+
struct InputDataResult {
50+
// A list of each node that was visited.
51+
std::vector<int32> node_indices;
52+
// The accumulator of the leaf that a data point ended up at, or -1 if none.
53+
int32 leaf_accumulator;
54+
// The left-branch taken candidate splits.
55+
std::vector<int32> split_adds;
56+
// If the candidate splits for the leaf that a data point arrived at
57+
// were initialized or not, which determines if we add this to total
58+
// pcw counts or not.
59+
bool splits_initialized;
60+
};
61+
62+
void Evaluate(const Tensor& input_data, const Tensor& input_labels,
63+
const Tensor& tree_tensor, const Tensor& tree_thresholds,
64+
const Tensor& node_to_accumulator,
65+
const Tensor& candidate_split_features,
66+
const Tensor& candidate_split_thresholds,
67+
InputDataResult* results, int64 start, int64 end) {
68+
const auto tree = tree_tensor.tensor<int32, 2>();
69+
const auto thresholds = tree_thresholds.unaligned_flat<float>();
70+
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
71+
const auto split_features = candidate_split_features.tensor<int32, 2>();
72+
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
73+
74+
const int32 num_splits = candidate_split_features.shape().dim_size(1);
75+
76+
for (int i = start; i < end; ++i) {
77+
const Tensor point = input_data.Slice(i, i + 1);
78+
int node_index = 0;
79+
results[i].splits_initialized = false;
80+
while (true) {
81+
results[i].node_indices.push_back(node_index);
82+
int32 left_child = tree(node_index, CHILDREN_INDEX);
83+
if (left_child == LEAF_NODE) {
84+
const int32 accumulator = node_map(node_index);
85+
results[i].leaf_accumulator = accumulator;
86+
// If the leaf is not fertile or is not yet initialized, we don't
87+
// count it in the candidate/total split per-class-weights because
88+
// it won't have any candidate splits yet.
89+
if (accumulator >= 0 &&
90+
IsAllInitialized(candidate_split_features.Slice(
91+
accumulator, accumulator + 1))) {
92+
results[i].splits_initialized = true;
93+
for (int split = 0; split < num_splits; split++) {
94+
if (!DecideNode(point, split_features(accumulator, split),
95+
split_thresholds(accumulator, split))) {
96+
results[i].split_adds.push_back(split);
97+
}
98+
}
99+
}
100+
break;
101+
} else if (left_child == FREE_NODE) {
102+
LOG(ERROR) << "Reached a free node, not good.";
103+
results[i].node_indices.push_back(FREE_NODE);
104+
break;
105+
}
106+
node_index =
107+
left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
108+
thresholds(node_index));
109+
}
110+
}
111+
}
112+
45113
REGISTER_OP("CountExtremelyRandomStats")
46114
.Attr("num_classes: int32")
47115
.Input("input_data: float")
@@ -79,9 +147,9 @@ REGISTER_OP("CountExtremelyRandomStats")
79147
gives the j-th feature of the i-th input.
80148
input_labels: The training batch's labels; `input_labels[i]` is the class
81149
of the i-th input.
82-
tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
83-
of the i-th node, `tree[0][i] + 1` gives the index of the right child of
84-
the i-th node, and `tree[1][i]` gives the index of the feature used to
150+
tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
151+
of the i-th node, `tree[i][0] + 1` gives the index of the right child of
152+
the i-th node, and `tree[i][1]` gives the index of the feature used to
85153
split the i-th node.
86154
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
87155
node.
@@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel {
176244
"candidate_split_features and candidate_split_thresholds should be "
177245
"the same shape."));
178246

179-
const int32 num_splits = candidate_split_features.shape().dim_size(1);
247+
// Evaluate input data in parallel.
248+
const int64 num_data = input_data.shape().dim_size(0);
249+
std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
250+
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
251+
int num_threads = worker_threads->num_threads;
252+
if (num_threads <= 1) {
253+
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
254+
node_to_accumulator, candidate_split_features,
255+
candidate_split_thresholds, results.get(), 0, num_data);
256+
} else {
257+
auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
258+
&node_to_accumulator, &candidate_split_features,
259+
&candidate_split_thresholds, &num_data,
260+
&results](int64 start, int64 end) {
261+
CHECK(start <= end);
262+
CHECK(end <= num_data);
263+
Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
264+
node_to_accumulator, candidate_split_features,
265+
candidate_split_thresholds, results.get(), start, end);
266+
};
267+
Shard(num_threads, worker_threads->workers, num_data, 100, work);
268+
}
269+
270+
// Set output tensors.
271+
const auto labels = input_labels.unaligned_flat<int32>();
180272

181273
// node pcw delta
182274
Tensor* output_node_pcw_delta = nullptr;
@@ -196,58 +288,28 @@ class CountExtremelyRandomStats : public OpKernel {
196288
&output_leaves));
197289
auto out_leaves = output_leaves->unaligned_flat<int32>();
198290

199-
const auto tree = tree_tensor.tensor<int32, 2>();
200-
const auto thresholds = tree_thresholds.unaligned_flat<float>();
201-
const auto labels = input_labels.unaligned_flat<int32>();
202-
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
203-
const auto split_features = candidate_split_features.tensor<int32, 2>();
204-
const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
205-
206-
const int32 num_data = input_data.shape().dim_size(0);
207-
208291
// <accumulator, class> -> count delta
209292
std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
210293
// <accumulator, split, class> -> count delta
211294
std::unordered_map<tuple<int32, int32, int32>,
212295
int32, TupleIntHash> split_delta;
213-
for (int i = 0; i < num_data; i++) {
214-
const Tensor point = input_data.Slice(i, i+1);
215-
int node_index = 0;
216-
while (true) {
217-
const int32 label = labels(i);
218-
++out_node(node_index, label);
219-
int32 left_child = tree(node_index, CHILDREN_INDEX);
220-
if (left_child == LEAF_NODE) {
221-
out_leaves(i) = node_index;
222-
const int32 accumulator = node_map(node_index);
223-
// If the leaf is not fertile or is not yet initialized, we don't
224-
// count it in the candidate/total split per-class-weights because
225-
// it won't have any candidate splits yet.
226-
if (accumulator >= 0 &&
227-
IsAllInitialized(
228-
candidate_split_features.Slice(accumulator,
229-
accumulator + 1))) {
230-
++total_delta[std::make_pair(accumulator, label)];
231-
for (int split = 0; split < num_splits; split++) {
232-
if (!DecideNode(point, split_features(accumulator, split),
233-
split_thresholds(accumulator, split))) {
234-
++split_delta[make_tuple(accumulator, split, label)];
235-
}
236-
}
237-
}
238-
break;
239-
} else if (left_child == FREE_NODE) {
240-
LOG(ERROR) << "Reached a free node, not good.";
241-
out_leaves(i) = FREE_NODE;
242-
break;
296+
297+
for (int32 i = 0; i < num_data; ++i) {
298+
const int32 label = labels(i);
299+
const int32 accumulator = results[i].leaf_accumulator;
300+
for (const int32 node : results[i].node_indices) {
301+
++out_node(node, label);
302+
}
303+
out_leaves(i) = results[i].node_indices.back();
304+
if (accumulator >= 0 && results[i].splits_initialized) {
305+
++total_delta[make_pair(accumulator, label)];
306+
for (const int32 split : results[i].split_adds) {
307+
++split_delta[make_tuple(accumulator, split, label)];
243308
}
244-
node_index = left_child +
245-
DecideNode(point, tree(node_index, FEATURE_INDEX),
246-
thresholds(node_index));
247309
}
248310
}
249311

250-
// candidate splits pcw indices
312+
// candidate splits pcw indices
251313
Tensor* output_candidate_pcw_indices = nullptr;
252314
TensorShape candidate_pcw_shape;
253315
candidate_pcw_shape.AddDim(split_delta.size());

tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class SampleInputs : public OpKernel {
9494
"split_sampling_random_seed", &split_sampling_random_seed_));
9595
// Set up the random number generator.
9696
if (split_sampling_random_seed_ == 0) {
97-
uint64 time_seed = static_cast<uint64>(std::time(NULL));
97+
uint64 time_seed = static_cast<uint64>(std::clock());
9898
single_rand_ = std::unique_ptr<random::PhiloxRandom>(
9999
new random::PhiloxRandom(time_seed));
100100
} else {

tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ REGISTER_OP("TreePredictions")
4444
4545
input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
4646
gives the j-th feature of the i-th input.
47-
tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
48-
of the i-th node, `tree[0][i] + 1` gives the index of the right child of
49-
the i-th node, and `tree[1][i]` gives the index of the feature used to
47+
tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
48+
of the i-th node, `tree[i][0] + 1` gives the index of the right child of
49+
the i-th node, and `tree[i][1]` gives the index of the feature used to
5050
split the i-th node.
5151
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
5252
node.

tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
import tensorflow # pylint: disable=unused-import
20+
import tensorflow as tf
2121

2222
from tensorflow.contrib.tensor_forest.python.ops import training_ops
2323

@@ -47,6 +47,29 @@ def testSimple(self):
4747
self.tree_thresholds, self.node_map,
4848
self.split_features, self.split_thresholds, num_classes=4))
4949

50+
self.assertAllEqual(
51+
[[1., 1., 1., 1.], [1., 1., 0., 0.], [0., 0., 1., 1.]],
52+
pcw_node.eval())
53+
self.assertAllEqual([[0, 0, 0]], pcw_splits_indices.eval())
54+
self.assertAllEqual([1.], pcw_splits_delta.eval())
55+
self.assertAllEqual([[0, 1], [0, 0]], pcw_totals_indices.eval())
56+
self.assertAllEqual([1., 1.], pcw_totals_delta.eval())
57+
self.assertAllEqual([1, 1, 2, 2], leaves.eval())
58+
59+
def testThreaded(self):
60+
with self.test_session(
61+
config=tf.ConfigProto(intra_op_parallelism_threads=2)):
62+
(pcw_node, pcw_splits_indices, pcw_splits_delta, pcw_totals_indices,
63+
pcw_totals_delta,
64+
leaves) = (self.ops.count_extremely_random_stats(self.input_data,
65+
self.input_labels,
66+
self.tree,
67+
self.tree_thresholds,
68+
self.node_map,
69+
self.split_features,
70+
self.split_thresholds,
71+
num_classes=4))
72+
5073
self.assertAllEqual([[1., 1., 1., 1.], [1., 1., 0., 0.],
5174
[0., 0., 1., 1.]],
5275
pcw_node.eval())

tensorflow/contrib/tensor_forest/python/ops/inference_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,13 @@ def TreePredictions(op):
4949
# there's not yet any guarantee that the shared object exists.
5050
# In which case, "import tensorflow" will always crash, even for users that
5151
# never use contrib.
52-
def Load():
52+
def Load(library_base_dir=''):
5353
"""Load the inference ops library and return the loaded module."""
5454
with _ops_lock:
5555
global _inference_ops
5656
if not _inference_ops:
57-
data_files_path = tf.resource_loader.get_data_files_path()
57+
data_files_path = os.path.join(library_base_dir,
58+
tf.resource_loader.get_data_files_path())
5859
tf.logging.info('data path: %s', data_files_path)
5960
_inference_ops = tf.load_op_library(os.path.join(
6061
data_files_path, INFERENCE_OPS_FILE))

tensorflow/contrib/tensor_forest/python/ops/training_ops.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow.python.framework import ops
2626
from tensorflow.python.framework import tensor_shape
2727

28+
2829
TRAINING_OPS_FILE = '_training_ops.so'
2930

3031
_training_ops = None
@@ -96,12 +97,13 @@ def _UpdateFertileSlotsShape(unused_op):
9697
# there's not yet any guarantee that the shared object exists.
9798
# In which case, "import tensorflow" will always crash, even for users that
9899
# never use contrib.
99-
def Load():
100+
def Load(library_base_dir=''):
100101
"""Load training ops library and return the loaded module."""
101102
with _ops_lock:
102103
global _training_ops
103104
if not _training_ops:
104-
data_files_path = tf.resource_loader.get_data_files_path()
105+
data_files_path = os.path.join(library_base_dir,
106+
tf.resource_loader.get_data_files_path())
105107
tf.logging.info('data path: %s', data_files_path)
106108
_training_ops = tf.load_op_library(os.path.join(
107109
data_files_path, TRAINING_OPS_FILE))

0 commit comments

Comments
 (0)