Skip to content

Commit 1e2a17a

Browse files
mrrymartinwicke
authored andcommitted
Fix bugs in the indexing operator with latest NumPy (tensorflow#1895)
* Made the indexing operator accept objects that can be converted with `int()`. Previously, the `_SliceHelper()` would test for `isinstance(..., int)`, which is too restrictive. * Replace `np.random.random_integers()` with `np.random.randint()`
1 parent bce8c55 commit 1e2a17a

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

tensorflow/python/kernel_tests/slice_op_test.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,15 @@ def _testSingleDimension(self, use_gpu):
7575
inp = np.random.rand(10).astype("f")
7676
a = tf.constant(inp, shape=[10], dtype=tf.float32)
7777

78-
hi = np.random.random_integers(0, 9)
78+
hi = np.random.randint(0, 9)
7979
scalar_t = a[hi]
8080
scalar_val = scalar_t.eval()
8181
self.assertAllEqual(scalar_val, inp[hi])
8282

83-
lo = np.random.random_integers(0, hi)
83+
if hi > 0:
84+
lo = np.random.randint(0, hi)
85+
else:
86+
lo = 0
8487
slice_t = a[lo:hi]
8588
slice_val = slice_t.eval()
8689
self.assertAllEqual(slice_val, inp[lo:hi])
@@ -110,7 +113,7 @@ def _testIndexAndSlice(self, use_gpu):
110113
inp = np.random.rand(4, 4).astype("f")
111114
a = tf.constant(inp, shape=[4, 4], dtype=tf.float32)
112115

113-
x, y = np.random.random_integers(0, 3, size=2).tolist()
116+
x, y = np.random.randint(0, 3, size=2).tolist()
114117
slice_t = a[x, 0:y]
115118
slice_val = slice_t.eval()
116119
self.assertAllEqual(slice_val, inp[x, 0:y])
@@ -142,9 +145,12 @@ def _testComplex(self, use_gpu):
142145
inp = np.random.rand(4, 10, 10, 4).astype("f")
143146
a = tf.constant(inp, dtype=tf.float32)
144147

145-
x = np.random.random_integers(0, 9)
146-
z = np.random.random_integers(0, 9)
147-
y = np.random.random_integers(0, z)
148+
x = np.random.randint(0, 9)
149+
z = np.random.randint(0, 9)
150+
if z > 0:
151+
y = np.random.randint(0, z)
152+
else:
153+
y = 0
148154
slice_t = a[:, x, y:z, :]
149155
self.assertAllEqual(slice_t.eval(), inp[:, x, y:z, :])
150156

tensorflow/python/ops/array_ops.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,7 @@ def _SliceHelper(tensor, slice_spec):
130130
sizes = []
131131
squeeze_dims = []
132132
for dim, s in enumerate(slice_spec):
133-
if isinstance(s, int):
134-
if s < 0:
135-
raise NotImplementedError("Negative indices are currently unsupported")
136-
indices.append(s)
137-
sizes.append(1)
138-
squeeze_dims.append(dim)
139-
elif isinstance(s, _baseslice):
133+
if isinstance(s, _baseslice):
140134
if s.step not in (None, 1):
141135
raise NotImplementedError(
142136
"Steps other than 1 are not currently supported")
@@ -161,7 +155,15 @@ def _SliceHelper(tensor, slice_spec):
161155
elif s is Ellipsis:
162156
raise NotImplementedError("Ellipsis is not currently supported")
163157
else:
164-
raise TypeError("Bad slice index %s of type %s" % (s, type(s)))
158+
try:
159+
s = int(s)
160+
except TypeError:
161+
raise TypeError("Bad slice index %s of type %s" % (s, type(s)))
162+
if s < 0:
163+
raise NotImplementedError("Negative indices are currently unsupported")
164+
indices.append(s)
165+
sizes.append(1)
166+
squeeze_dims.append(dim)
165167
sliced = slice(tensor, indices, sizes)
166168
if squeeze_dims:
167169
return squeeze(sliced, squeeze_dims=squeeze_dims)

0 commit comments

Comments
 (0)