// only op that involves tree traversal, and is constructed so that it can
// be run in parallel on separate batches of data.
#include <unordered_map>
+#include <vector>

#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using std::get;
+using std::make_pair;
using std::make_tuple;
using std::pair;
using std::tuple;
@@ -42,6 +45,71 @@ using tensorforest::DecideNode;
using tensorforest::Initialize;
using tensorforest::IsAllInitialized;

+// A data structure to store the results of parallel tree traversal.
+struct InputDataResult {
+  // A list of each node that was visited.
+  std::vector<int32> node_indices;
+  // The accumulator of the leaf that a data point ended up at, or -1 if none.
+  int32 leaf_accumulator;
+  // The candidate splits for which the data point took the left branch.
+  std::vector<int32> split_adds;
+  // Whether the candidate splits for the leaf that the data point arrived at
+  // were initialized, which determines whether we add this point to the total
+  // pcw counts.
+  bool splits_initialized;
+};
+
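+// Traverses the tree for each data point in rows [start, end) of input_data,
+// recording every node visited, the accumulator of the leaf reached, and the
+// candidate splits for which the point took the left branch. Constructed so
+// that disjoint ranges of the batch can be evaluated in parallel.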
+void Evaluate(const Tensor& input_data, const Tensor& input_labels,
+              const Tensor& tree_tensor, const Tensor& tree_thresholds,
+              const Tensor& node_to_accumulator,
+              const Tensor& candidate_split_features,
+              const Tensor& candidate_split_thresholds,
+              InputDataResult* results, int64 start, int64 end) {
+  const auto tree = tree_tensor.tensor<int32, 2>();
+  const auto thresholds = tree_thresholds.unaligned_flat<float>();
+  const auto node_map = node_to_accumulator.unaligned_flat<int32>();
+  const auto split_features = candidate_split_features.tensor<int32, 2>();
+  const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
+
+  const int32 num_splits = candidate_split_features.shape().dim_size(1);
+
+  for (int i = start; i < end; ++i) {
+    const Tensor point = input_data.Slice(i, i + 1);
+    int node_index = 0;
+    results[i].splits_initialized = false;
+    while (true) {
+      results[i].node_indices.push_back(node_index);
+      int32 left_child = tree(node_index, CHILDREN_INDEX);
+      if (left_child == LEAF_NODE) {
+        const int32 accumulator = node_map(node_index);
+        results[i].leaf_accumulator = accumulator;
+        // If the leaf is not fertile or is not yet initialized, we don't
+        // count it in the candidate/total split per-class-weights because
+        // it won't have any candidate splits yet.
+        if (accumulator >= 0 &&
+            IsAllInitialized(candidate_split_features.Slice(
+                accumulator, accumulator + 1))) {
+          results[i].splits_initialized = true;
+          for (int split = 0; split < num_splits; split++) {
+            if (!DecideNode(point, split_features(accumulator, split),
+                            split_thresholds(accumulator, split))) {
+              results[i].split_adds.push_back(split);
+            }
+          }
+        }
+        break;
+      } else if (left_child == FREE_NODE) {
+        LOG(ERROR) << "Reached a free node, not good.";
+        results[i].node_indices.push_back(FREE_NODE);
+        break;
+      }
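+      // DecideNode returns false (0) to follow the left child and true (1)
+      // to follow the right child, which is stored at left_child + 1.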
+      node_index =
+          left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
+                                  thresholds(node_index));
+    }
+  }
+}
+
REGISTER_OP("CountExtremelyRandomStats")
    .Attr("num_classes: int32")
    .Input("input_data: float")
@@ -79,9 +147,9 @@ REGISTER_OP("CountExtremelyRandomStats")
  gives the j-th feature of the i-th input.
input_labels: The training batch's labels; `input_labels[i]` is the class
  of the i-th input.
-tree:= A 2-d int32 tensor. `tree[0][i]` gives the index of the left child
-  of the i-th node, `tree[0][i] + 1` gives the index of the right child of
-  the i-th node, and `tree[1][i]` gives the index of the feature used to
+tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child
+  of the i-th node, `tree[i][0] + 1` gives the index of the right child of
+  the i-th node, and `tree[i][1]` gives the index of the feature used to
  split the i-th node.
tree_thresholds: `tree_thresholds[i]` is the value used to split the i-th
  node.
@@ -176,7 +244,31 @@ class CountExtremelyRandomStats : public OpKernel {
            "candidate_split_features and candidate_split_thresholds should be "
            "the same shape."));

-    const int32 num_splits = candidate_split_features.shape().dim_size(1);
+    // Evaluate input data in parallel.
+    const int64 num_data = input_data.shape().dim_size(0);
+    std::unique_ptr<InputDataResult[]> results(new InputDataResult[num_data]);
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    int num_threads = worker_threads->num_threads;
+    if (num_threads <= 1) {
+      Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
+               node_to_accumulator, candidate_split_features,
+               candidate_split_thresholds, results.get(), 0, num_data);
+    } else {
+      auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds,
+                   &node_to_accumulator, &candidate_split_features,
+                   &candidate_split_thresholds, &num_data,
+                   &results](int64 start, int64 end) {
+        CHECK(start <= end);
+        CHECK(end <= num_data);
+        Evaluate(input_data, input_labels, tree_tensor, tree_thresholds,
+                 node_to_accumulator, candidate_split_features,
+                 candidate_split_thresholds, results.get(), start, end);
+      };
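+      // Shard splits [0, num_data) into subranges and runs 'work' on each,
+      // using the kernel's CPU thread pool; 100 is the estimated cost of
+      // processing one data point, which guides how finely to shard.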
+      Shard(num_threads, worker_threads->workers, num_data, 100, work);
+    }
+
+    // Set output tensors.
+    const auto labels = input_labels.unaligned_flat<int32>();

    // node pcw delta
    Tensor* output_node_pcw_delta = nullptr;
@@ -196,58 +288,28 @@ class CountExtremelyRandomStats : public OpKernel {
                                     &output_leaves));
    auto out_leaves = output_leaves->unaligned_flat<int32>();

-    const auto tree = tree_tensor.tensor<int32, 2>();
-    const auto thresholds = tree_thresholds.unaligned_flat<float>();
-    const auto labels = input_labels.unaligned_flat<int32>();
-    const auto node_map = node_to_accumulator.unaligned_flat<int32>();
-    const auto split_features = candidate_split_features.tensor<int32, 2>();
-    const auto split_thresholds = candidate_split_thresholds.tensor<float, 2>();
-
-    const int32 num_data = input_data.shape().dim_size(0);
-
    // <accumulator, class> -> count delta
    std::unordered_map<pair<int32, int32>, int32, PairIntHash> total_delta;
    // <accumulator, split, class> -> count delta
    std::unordered_map<tuple<int32, int32, int32>,
                       int32, TupleIntHash> split_delta;
-    for (int i = 0; i < num_data; i++) {
-      const Tensor point = input_data.Slice(i, i+1);
-      int node_index = 0;
-      while (true) {
-        const int32 label = labels(i);
-        ++out_node(node_index, label);
-        int32 left_child = tree(node_index, CHILDREN_INDEX);
-        if (left_child == LEAF_NODE) {
-          out_leaves(i) = node_index;
-          const int32 accumulator = node_map(node_index);
-          // If the leaf is not fertile or is not yet initialized, we don't
-          // count it in the candidate/total split per-class-weights because
-          // it won't have any candidate splits yet.
-          if (accumulator >= 0 &&
-              IsAllInitialized(
-                  candidate_split_features.Slice(accumulator,
-                                                 accumulator + 1))) {
-            ++total_delta[std::make_pair(accumulator, label)];
-            for (int split = 0; split < num_splits; split++) {
-              if (!DecideNode(point, split_features(accumulator, split),
-                              split_thresholds(accumulator, split))) {
-                ++split_delta[make_tuple(accumulator, split, label)];
-              }
-            }
-          }
-          break;
-        } else if (left_child == FREE_NODE) {
-          LOG(ERROR) << "Reached a free node, not good.";
-          out_leaves(i) = FREE_NODE;
-          break;
+
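+    // Merge the per-data-point traversal results into the node/leaf outputs
+    // and the per-accumulator count-delta maps.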
+    for (int32 i = 0; i < num_data; ++i) {
+      const int32 label = labels(i);
+      const int32 accumulator = results[i].leaf_accumulator;
+      for (const int32 node : results[i].node_indices) {
+        ++out_node(node, label);
+      }
+      out_leaves(i) = results[i].node_indices.back();
+      if (accumulator >= 0 && results[i].splits_initialized) {
+        ++total_delta[make_pair(accumulator, label)];
+        for (const int32 split : results[i].split_adds) {
+          ++split_delta[make_tuple(accumulator, split, label)];
        }
-        node_index = left_child +
-            DecideNode(point, tree(node_index, FEATURE_INDEX),
-                       thresholds(node_index));
      }
    }

-     // candidate splits pcw indices
+    // candidate splits pcw indices
    Tensor* output_candidate_pcw_indices = nullptr;
    TensorShape candidate_pcw_shape;
    candidate_pcw_shape.AddDim(split_delta.size());