Skip to content

Commit 083c4dc

Browse files
tensorflower-gardenergunan
authored andcommitted
Change StridedSlice to error on scalar input, in both
the shape inference function and the kernel. Change: 134434589
1 parent 425f49e commit 083c4dc

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ REGISTER_OP("CudnnRNNParamsSize")
7272
.Attr(kRNNInputModeAttrs)
7373
.Attr(kRNNDirectionAttrs)
7474
.Output("params_size: S")
75-
.SetShapeFn(shape_inference::ScalarShape)
75+
.SetShapeFn([](InferenceContext* c) {
76+
c->set_output(0, c->Vector(1));
77+
return Status::OK();
78+
})
7679
.Doc(strings::StrCat(R"doc(
7780
Return the params size that can be used by the Cudnn RNN model. Subsequent
7881
weight allocation and initialization should use this size.

tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace tensorflow {
2626

2727
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
2828
ShapeInferenceTestOp op("CudnnRNNParamsSize");
29-
INFER_OK(op, "[1];[1];[1]", "[]");
29+
INFER_OK(op, "[1];[1];[1]", "[1]");
3030
}
3131

3232
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {

tensorflow/core/util/strided_slice_op.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ struct StridedSliceDenseSpec {
9494
} // namespace
9595

9696
template <class T>
97-
static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
98-
StridedSliceDenseSpec* dense) {
97+
static Status TF_MUST_USE_RESULT BuildDenseSpec(
98+
const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
9999
// Build expanded begin, end, strides, begin_mask, end_mask
100100
// to remove any ellipsis
101101
dense->begin.resize(dense->dims);
@@ -130,6 +130,12 @@ static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
130130
} else if ((1 << i) & sparse.new_axis_mask) {
131131
dense->final_shape_gather_indices.push_back(kNewAxis);
132132
} else {
133+
if (full_index == dense->begin.size()) {
134+
return errors::InvalidArgument("Index out of range using input dim ",
135+
full_index, "; input has only ",
136+
dense->dims, " dims");
137+
}
138+
133139
// Gather slicing spec into appropriate index
134140
dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
135141
dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
@@ -154,6 +160,7 @@ static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
154160
}
155161
}
156162
}
163+
return Status::OK();
157164
}
158165

159166
Status ValidateStridedSliceOp(
@@ -233,9 +240,9 @@ Status ValidateStridedSliceOp(
233240
input_shape.dims(), 0, 0, *begin, *end, *strides};
234241

235242
if (begin_tensor.dtype() == DT_INT32) {
236-
BuildDenseSpec<int32>(sparse_spec, &dense_spec);
243+
TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
237244
} else if (begin_tensor.dtype() == DT_INT64) {
238-
BuildDenseSpec<int64>(sparse_spec, &dense_spec);
245+
TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
239246
} else {
240247
LOG(FATAL) << "begin must be either int32 or int64";
241248
}

tensorflow/python/kernel_tests/slice_op_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,34 @@ def testSingleDimension(self):
7777
slice_val = slice_t.eval()
7878
self.assertAllEqual(slice_val, inp[lo:hi])
7979

80+
def testScalarInput(self):
81+
input_val = 0
82+
with self.test_session() as sess:
83+
# Test with constant input; shape inference fails.
84+
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
85+
tf.constant(input_val)[:].get_shape()
86+
87+
# Test evaluating with non-constant input; kernel execution fails.
88+
input_t = tf.placeholder(tf.int32)
89+
slice_t = input_t[:]
90+
with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
91+
"out of range"):
92+
sess.run([slice_t], feed_dict={input_t: input_val})
93+
94+
def testInvalidIndex(self):
95+
input_val = [1, 2]
96+
with self.test_session() as sess:
97+
# Test with constant input; shape inference fails.
98+
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
99+
tf.constant(input_val)[1:, 1:].get_shape()
100+
101+
# Test evaluating with non-constant input; kernel execution fails.
102+
input_t = tf.placeholder(tf.int32)
103+
slice_t = input_t[1:, 1:]
104+
with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
105+
"out of range"):
106+
sess.run([slice_t], feed_dict={input_t: input_val})
107+
80108
def _testSliceMatrixDim0(self, x, begin, size):
81109
with self.test_session(use_gpu=True):
82110
tf_ans = tf.slice(x, [begin, 0], [size, x.shape[1]]).eval()

0 commit comments

Comments
 (0)