Skip to content

Commit 400a398

Browse files
yongtangdrpngx
authored andcommitted
Require same shape for x and y in shape function of ApproximateEqual (tensorflow#19878)
* Require same shape for `x` and `y` in shape function of `ApproximateEqual` In the kernel implementation of `ApproximateEqual` the shape of inputs `x` and `y` should be the same. Though in the shape function of `ApproximateEqual` there was no such validation. This fix adds the shape validation in the shape function to make sure `x` and `y` are of the same shape, if they are known. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add test case for shape function of ApproximateEqual Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
1 parent 5fa7b03 commit 400a398

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tensorflow/core/ops/math_ops.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,13 @@ REGISTER_OP("ApproximateEqual")
592592
.SetIsCommutative()
593593
.Attr("T: numbertype")
594594
.Attr("tolerance: float = 0.00001")
595-
.SetShapeFn(shape_inference::UnchangedShape);
595+
.SetShapeFn([](InferenceContext* c) {
596+
// The inputs 'x' and 'y' must have the same shape.
597+
ShapeHandle data_x = c->input(0);
598+
ShapeHandle data_y = c->input(1);
599+
TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
600+
return shape_inference::UnchangedShape(c);
601+
});
596602

597603
// --------------------------------------------------------------------------
598604

tensorflow/python/ops/math_ops_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,15 @@ def testApproximateEqual(self):
235235
z_tf = self.evaluate(math_ops.approximate_equal(x, y, tolerance=0.0001))
236236
self.assertAllEqual(z, z_tf)
237237

238+
def testApproximateEqualShape(self):
239+
for dtype in [np.float32, np.double]:
240+
x = np.array([1, 2], dtype=np.float32)
241+
y = np.array([[1, 2]], dtype=np.float32)
242+
# The inputs 'x' and 'y' must have the same shape.
243+
with self.assertRaisesRegexp(
244+
ValueError, "Shapes must be equal rank, but are 1 and 2"):
245+
math_ops.approximate_equal(x, y)
246+
238247

239248
class ScalarMulTest(test_util.TensorFlowTestCase):
240249

0 commit comments

Comments
 (0)