Skip to content

Commit a0802b5

Browse files
tensorflower-gardenercwhipkey
authored andcommitted
Add the graphdef version to InferenceContext and to ShapeRefiner::AddNode.
Use this to allow loading reductions saved with older graphdefs. Change GraphConstructor to not increase the version when importing, but instead take the min of all versions. Change: 149152437
1 parent 29a6b46 commit a0802b5

25 files changed

+337
-168
lines changed

tensorflow/c/c_api.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ extern "C" {
729729
struct TF_Graph {
730730
TF_Graph()
731731
: graph(OpRegistry::Global()),
732-
refiner(graph.op_registry()),
732+
refiner(graph.versions().producer(), graph.op_registry()),
733733
num_sessions(0),
734734
delete_requested(false) {}
735735
mutex mu;

tensorflow/cc/framework/scope.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map,
3434

3535
Scope Scope::NewRootScope() {
3636
Graph* graph = new Graph(OpRegistry::Global());
37-
ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry());
37+
ShapeRefiner* refiner = new ShapeRefiner(
38+
graph->versions().producer(), graph->op_registry());
3839
return Scope(graph, new Status, new Scope::NameMap, refiner);
3940
}
4041

tensorflow/core/common_runtime/shape_refiner.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ using shape_inference::DimensionHandle;
3131
using shape_inference::InferenceContext;
3232
using shape_inference::ShapeHandle;
3333

34-
ShapeRefiner::ShapeRefiner(const OpRegistryInterface* ops)
35-
: ops_registry_(ops) {}
34+
ShapeRefiner::ShapeRefiner(int graph_def_version,
35+
const OpRegistryInterface* ops)
36+
: graph_def_version_(graph_def_version), ops_registry_(ops) {}
3637

3738
ShapeRefiner::~ShapeRefiner() { gtl::STLDeleteValues(&node_to_context_); }
3839

@@ -87,9 +88,10 @@ Status ShapeRefiner::AddNode(const Node* node) {
8788
std::vector<ShapeHandle> input_tensors_as_shapes;
8889

8990
// Create the inference context for this node with the existing input shapes.
90-
std::unique_ptr<InferenceContext> c(new InferenceContext(
91-
&node->def(), node->op_def(), input_shapes, input_tensors,
92-
input_tensors_as_shapes, input_handle_shapes, input_handle_dtypes));
91+
std::unique_ptr<InferenceContext> c(
92+
new InferenceContext(graph_def_version_, &node->def(), node->op_def(),
93+
input_shapes, input_tensors, input_tensors_as_shapes,
94+
input_handle_shapes, input_handle_dtypes));
9395
if (!c->construction_status().ok()) {
9496
return c->construction_status();
9597
}

tensorflow/core/common_runtime/shape_refiner.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace tensorflow {
3131
// construction time.
3232
class ShapeRefiner {
3333
public:
34-
explicit ShapeRefiner(const OpRegistryInterface* ops);
34+
ShapeRefiner(int graph_def_version, const OpRegistryInterface* ops);
3535
~ShapeRefiner();
3636

3737
// Performs validation of 'node' and runs 'node's shape function,
@@ -99,7 +99,8 @@ class ShapeRefiner {
9999
const Node* node, int dst_idx,
100100
shape_inference::ShapeHandle* result);
101101

102-
const OpRegistryInterface* ops_registry_ = nullptr;
102+
const int graph_def_version_;
103+
const OpRegistryInterface* const ops_registry_;
103104

104105
// Stores a map from a node to its InferenceContext.
105106
//

tensorflow/core/common_runtime/shape_refiner_test.cc

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License.
2323
#include "tensorflow/core/lib/core/status.h"
2424
#include "tensorflow/core/lib/core/status_test_util.h"
2525
#include "tensorflow/core/platform/test.h"
26+
#include "tensorflow/core/public/version.h"
2627

2728
namespace tensorflow {
2829
namespace {
@@ -38,14 +39,14 @@ TEST(ShapeRefinerTest, Constant) {
3839
// and that its shape is correct.
3940
Scope root = Scope::NewRootScope();
4041
auto c = ops::Const(root, 42.0f);
41-
ShapeRefiner m(OpRegistry::Global());
42+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
4243
TF_ASSERT_OK(m.AddNode(c.node()));
4344

4445
EXPECT_SHAPE("[]", m, c, 0);
4546
}
4647

4748
TEST(ShapeRefinerTest, MatMul) {
48-
ShapeRefiner m(OpRegistry::Global());
49+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
4950

5051
Scope root = Scope::NewRootScope();
5152
auto a = ops::Const(root, {{1.0f}, {2.0f}});
@@ -62,7 +63,7 @@ TEST(ShapeRefinerTest, MatMul) {
6263
}
6364

6465
TEST(ShapeRefinerTest, InvalidOrder) {
65-
ShapeRefiner m(OpRegistry::Global());
66+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
6667
Scope root = Scope::NewRootScope();
6768
auto a = ops::Const(root, {{1.0f}, {2.0f}});
6869
auto b = ops::Const(root, {{1.0f, 2.0f}});
@@ -77,7 +78,7 @@ TEST(ShapeRefinerTest, InvalidOrder) {
7778
}
7879

7980
TEST(ShapeRefinerTest, BadShapes) {
80-
ShapeRefiner m(OpRegistry::Global());
81+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
8182
Scope root = Scope::NewRootScope();
8283
auto a = ops::Const(root, {{1.0f}, {2.0f}});
8384
auto b = ops::Const(root, {{1.0f}, {2.0f}});
@@ -94,7 +95,7 @@ TEST(ShapeRefinerTest, BadShapes) {
9495
}
9596

9697
TEST(ShapeRefinerTest, SetShape) {
97-
ShapeRefiner m(OpRegistry::Global());
98+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
9899

99100
Scope root = Scope::NewRootScope();
100101
auto a = ops::Placeholder(root, DT_FLOAT);
@@ -136,7 +137,7 @@ TEST(ShapeRefinerTest, PropagateConstants) {
136137
auto dim = ops::Variable(root, {}, DT_INT32);
137138

138139
auto am = ops::ArgMax(root, input, dim);
139-
ShapeRefiner m(OpRegistry::Global());
140+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
140141
TF_ASSERT_OK(m.AddNode(input.node()));
141142
TF_ASSERT_OK(m.AddNode(dim.node()));
142143
TF_ASSERT_OK(m.AddNode(am.node()));
@@ -153,7 +154,7 @@ TEST(ShapeRefinerTest, PropagateConstants) {
153154
auto dim = ops::Const(root, 1);
154155

155156
auto am = ops::ArgMax(root, input, dim);
156-
ShapeRefiner m(OpRegistry::Global());
157+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
157158
TF_ASSERT_OK(m.AddNode(input.node()));
158159
TF_ASSERT_OK(m.AddNode(dim.node()));
159160
TF_ASSERT_OK(m.AddNode(am.node()));
@@ -169,7 +170,7 @@ TEST(ShapeRefinerTest, PropagateConstants) {
169170
auto dim = ops::Const(root, 0);
170171

171172
auto am = ops::ArgMax(root, input, dim);
172-
ShapeRefiner m(OpRegistry::Global());
173+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
173174
TF_ASSERT_OK(m.AddNode(input.node()));
174175
TF_ASSERT_OK(m.AddNode(dim.node()));
175176
TF_ASSERT_OK(m.AddNode(am.node()));
@@ -199,7 +200,7 @@ REGISTER_OP("TestOp")
199200
} // namespace
200201

201202
TEST(ShapeRefinerTest, InputTensorDependencies) {
202-
ShapeRefiner m(OpRegistry::Global());
203+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
203204
Graph graph(OpRegistry::Global());
204205
Node* node;
205206

@@ -260,7 +261,7 @@ TEST(ShapeRefinerTest, PropagateShape) {
260261
.Input(shape.node())
261262
.Finalize(root.graph(), &shape_data));
262263

263-
ShapeRefiner m(OpRegistry::Global());
264+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
264265
TF_ASSERT_OK(m.AddNode(input.node()));
265266
TF_ASSERT_OK(m.AddNode(shape.node()));
266267
TF_ASSERT_OK(m.AddNode(shape_data));
@@ -281,7 +282,7 @@ TEST(ShapeRefinerTest, PropagateSize) {
281282
.Input(size.node())
282283
.Finalize(root.graph(), &shape_data));
283284

284-
ShapeRefiner m(OpRegistry::Global());
285+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
285286
TF_ASSERT_OK(m.AddNode(input.node()));
286287
TF_ASSERT_OK(m.AddNode(size.node()));
287288
TF_ASSERT_OK(m.AddNode(shape_data));
@@ -302,7 +303,7 @@ TEST(ShapeRefinerTest, PropagateRank) {
302303
.Input(rank.node())
303304
.Finalize(root.graph(), &shape_data));
304305

305-
ShapeRefiner m(OpRegistry::Global());
306+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
306307
TF_ASSERT_OK(m.AddNode(input.node()));
307308
TF_ASSERT_OK(m.AddNode(rank.node()));
308309
TF_ASSERT_OK(m.AddNode(shape_data));
@@ -323,7 +324,7 @@ TEST(ShapeRefinerTest, PropagateRange) {
323324
.Input(range.node())
324325
.Finalize(root.graph(), &shape_data));
325326

326-
ShapeRefiner m(OpRegistry::Global());
327+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
327328
TF_ASSERT_OK(m.AddNode(begin.node()));
328329
TF_ASSERT_OK(m.AddNode(limit.node()));
329330
TF_ASSERT_OK(m.AddNode(delta.node()));
@@ -346,7 +347,7 @@ TEST(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) {
346347
.Input(range.node())
347348
.Finalize(root.graph(), &shape_data));
348349

349-
ShapeRefiner m(OpRegistry::Global());
350+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
350351
TF_ASSERT_OK(m.AddNode(begin_and_delta.node()));
351352
TF_ASSERT_OK(m.AddNode(limit.node()));
352353
TF_ASSERT_OK(m.AddNode(range.node()));
@@ -381,7 +382,7 @@ TEST(ShapeRefinerTest, ConstantValueVisitNodeTwice) {
381382
.Input(range.node())
382383
.Finalize(root.graph(), &shape_data));
383384

384-
ShapeRefiner m(OpRegistry::Global());
385+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
385386
TF_ASSERT_OK(m.AddNode(begin.node()));
386387
TF_ASSERT_OK(m.AddNode(limit.node()));
387388
TF_ASSERT_OK(m.AddNode(delta.node()));
@@ -477,7 +478,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) {
477478
.Input(input)
478479
.Finalize(root.graph(), &result));
479480

480-
ShapeRefiner m(OpRegistry::Global());
481+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
481482
TF_ASSERT_OK(m.AddNode(input));
482483
TF_ASSERT_OK(m.AddNode(result));
483484

@@ -498,7 +499,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) {
498499
.Input(shape.node())
499500
.Finalize(root.graph(), &result));
500501

501-
ShapeRefiner m(OpRegistry::Global());
502+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
502503
TF_ASSERT_OK(m.AddNode(input));
503504
TF_ASSERT_OK(m.AddNode(shape.node()));
504505
TF_ASSERT_OK(m.AddNode(result));
@@ -533,7 +534,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) {
533534
.Input(pack.node())
534535
.Finalize(root.graph(), &result));
535536

536-
ShapeRefiner m(OpRegistry::Global());
537+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
537538
for (auto input : inputs) {
538539
TF_ASSERT_OK(m.AddNode(input.node()));
539540
}
@@ -565,7 +566,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) {
565566
.Input(pack.node())
566567
.Finalize(root.graph(), &result));
567568

568-
ShapeRefiner m(OpRegistry::Global());
569+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
569570
for (const auto& input : inputs) {
570571
TF_ASSERT_OK(m.AddNode(input.node()));
571572
}
@@ -591,7 +592,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) {
591592
.Input(pack.node())
592593
.Finalize(root.graph(), &result));
593594

594-
ShapeRefiner m(OpRegistry::Global());
595+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
595596
for (const auto& input : inputs) {
596597
TF_ASSERT_OK(m.AddNode(input.node()));
597598
}
@@ -618,7 +619,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) {
618619
.Input(pack.node())
619620
.Finalize(root.graph(), &result));
620621

621-
ShapeRefiner m(OpRegistry::Global());
622+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
622623
for (const auto& input : inputs) {
623624
TF_ASSERT_OK(m.AddNode(input.node()));
624625
}
@@ -650,7 +651,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) {
650651
.Input(concat.node())
651652
.Finalize(g, &result));
652653

653-
ShapeRefiner m(OpRegistry::Global());
654+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
654655
TF_ASSERT_OK(m.AddNode(partial_1));
655656
TF_ASSERT_OK(m.AddNode(partial_2));
656657
for (const auto& o : concat_inputs) {
@@ -692,7 +693,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) {
692693
.Input(concat.node())
693694
.Finalize(g, &result));
694695

695-
ShapeRefiner m(OpRegistry::Global());
696+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
696697
TF_ASSERT_OK(m.AddNode(partial_1));
697698
TF_ASSERT_OK(m.AddNode(partial_2));
698699
TF_ASSERT_OK(m.AddNode(unknown));
@@ -734,7 +735,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
734735
.Input(concat.node())
735736
.Finalize(g, &result));
736737

737-
ShapeRefiner m(OpRegistry::Global());
738+
ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
738739
TF_ASSERT_OK(m.AddNode(partial_1));
739740
TF_ASSERT_OK(m.AddNode(partial_2));
740741
for (const auto& o : concat_inputs) {

tensorflow/core/framework/common_shape_fns.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,13 @@ Status ReductionShape(InferenceContext* c) {
590590
ShapeHandle input = c->input(0);
591591

592592
ShapeHandle indices;
593-
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
593+
// Older versions of TensorFlow accidentally allowed higher rank tensors like
594+
// [[1,2]] or [[1],[2]] to represent axis=[1,2].
595+
if (c->graph_def_version() < 21) {
596+
indices = c->input(1);
597+
} else {
598+
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices));
599+
}
594600

595601
bool keep_dims;
596602
TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims));

0 commit comments

Comments
 (0)