Skip to content

Commit 5ce2567

Browse files
Manjunath Kudlurgunan
Manjunath Kudlur
authored andcommitted
C++ API: Added a Const constructor for non-empty const supporting type cast.
Fixes tensorflow#3752 Change: 130113000
1 parent 67734a1 commit 5ce2567

File tree

3 files changed

+78
-13
lines changed

3 files changed

+78
-13
lines changed

tensorflow/cc/framework/cc_ops_test.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,4 +226,49 @@ TEST(CCOpTest, ColocateWith) {
226226
EXPECT_TRUE(attrs.find("_class") == attrs.end());
227227
}
228228

229+
TEST(CCOpTest, TemplatedConst) {
230+
Scope root = Scope::NewRootScope();
231+
auto c1 = ops::Const<float>(root, {{3, 2}, {-1, 0}});
232+
TF_EXPECT_OK(root.status());
233+
234+
Tensor out;
235+
GetTensor(root, c1, &out);
236+
test::ExpectTensorEqual<float>(
237+
out, test::AsTensor<float>({3.f, 2.f, -1.f, 0.f}, {2, 2}));
238+
239+
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
240+
GetTensor(root, c2, &out);
241+
test::ExpectTensorEqual<string>(
242+
out, test::AsTensor<string>({"this", "is", "a", "constant"}, {4, 1}));
243+
}
244+
245+
TEST(CCOpTest, EmptyConst) {
246+
Scope root = Scope::NewRootScope();
247+
248+
auto c1 = ops::Const(root, {});
249+
TF_CHECK_OK(root.status());
250+
251+
Tensor out;
252+
GetTensor(root, c1, &out);
253+
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {0}));
254+
255+
auto c2 = ops::Const(root, {{}});
256+
TF_CHECK_OK(root.status());
257+
GetTensor(root, c2, &out);
258+
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {1, 0}));
259+
260+
auto c3 = ops::Const(root, {{{}, {}}});
261+
TF_CHECK_OK(root.status());
262+
GetTensor(root, c3, &out);
263+
test::ExpectTensorEqual<float>(out, Tensor(DT_FLOAT, {1, 2, 0}));
264+
265+
auto c4 = ops::Const<int>(root, {{{}}});
266+
TF_CHECK_OK(root.status());
267+
GetTensor(root, c4, &out);
268+
test::ExpectTensorEqual<int>(out, Tensor(DT_INT32, {1, 1, 0}));
269+
270+
ops::Const(root, {{}, {{}}});
271+
EXPECT_FALSE(root.status().ok());
272+
}
273+
229274
} // namespace tensorflow

tensorflow/cc/ops/const_op.h

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,35 @@ namespace ops {
2525

2626
Output Const(const Scope& scope, const Input::Initializer& val);
2727

28+
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
29+
2830
template <typename T>
2931
Output Const(const Scope& scope, const Input::Initializer& val) {
32+
auto orig_const_output = Const(scope, val);
3033
if (!scope.ok()) return Output();
31-
if (!val.status.ok()) {
32-
scope.UpdateStatus(val.status);
33-
return Output();
34-
}
34+
3535
typedef typename Input::Initializer::RealType<T>::type DstT;
36-
if (val.tensor.NumElements() > 0) {
37-
// TODO(keveman): Implement the in-situ cast.
38-
scope.UpdateStatus(errors::Unimplemented(
39-
"Explict cast of a non-empty tensor not implemented yet"));
40-
return Output();
36+
37+
if (val.tensor.dtype() == DataTypeToEnum<DstT>::v()) {
38+
return orig_const_output;
4139
}
42-
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
43-
return Const(scope, Input::Initializer(t));
40+
if (val.tensor.NumElements() == 0) {
41+
Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape());
42+
return Const(scope, Input::Initializer(t));
43+
}
44+
45+
// TODO(keveman): Refactor Cast op's kernel implementation such that the code
46+
// can be directly called here instead of adding the Cast op to the graph.
47+
auto orig_const = AsNodeOut(scope, orig_const_output);
48+
const auto cast_op_name = scope.GetUniqueNameForOp("Cast");
49+
50+
auto cast_builder = NodeBuilder(cast_op_name, "Cast")
51+
.Input(orig_const)
52+
.Attr("DstT", DataTypeToEnum<DstT>::v());
53+
scope.UpdateBuilder(&cast_builder);
54+
Node* ret;
55+
scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret));
56+
return Output(ret, 0);
4457
}
4558

4659
template <typename T>
@@ -54,8 +67,6 @@ Output Const(const Scope& scope, const std::initializer_list<T>& v,
5467
return Const(scope, Input::Initializer(v, shape));
5568
}
5669

57-
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
58-
5970
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
6071
const InputList& inp);
6172

tensorflow/cc/ops/const_op_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,13 @@ TEST(ConstOpTest, Names) {
125125
EXPECT_EQ(c_y_1.node()->name(), "c/y_1");
126126
}
127127

128+
TEST(ConstOpTest, TemplatedConst) {
129+
Scope root = Scope::NewRootScope();
130+
auto c1 = ops::Const<int>(root, {1, 2});
131+
ExpectTypeAndShape(c1.node(), DT_INT32, {2});
132+
133+
auto c2 = ops::Const<string>(root, {{"this"}, {"is"}, {"a"}, {"constant"}});
134+
ExpectTypeAndShape(c2.node(), DT_STRING, {4, 1});
135+
}
136+
128137
} // namespace tensorflow

0 commit comments

Comments
 (0)