Skip to content

Commit 9d8201c

Browse files
langmore and tensorflower-gardener
authored and committed
BUGFIX: sample_stats.percentile: The fractions used for interpolation were
computed in float, which caused the error d - 1 == d for large 'd'. The fix is to compute them in double. Index values are also clipped to [0, ..., d - 1] in case double isn't enough for some huge array.
PiperOrigin-RevId: 191023164
1 parent 8712a7f commit 9d8201c

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tensorflow.contrib.distributions.python.ops import sample_stats
2424
from tensorflow.python.framework import dtypes
2525
from tensorflow.python.ops import array_ops
26+
from tensorflow.python.ops import math_ops
2627
from tensorflow.python.ops import spectral_ops_test_util
2728
from tensorflow.python.platform import test
2829

@@ -455,6 +456,16 @@ def test_vector_q_raises_dynamic(self):
455456
with self.assertRaisesOpError("rank"):
456457
pct.eval(feed_dict={q_ph: [0.5]})
457458

459+
def test_finds_max_of_long_array(self):
  """Checks that percentile does not index out of bounds on a 3e7-long vector.

  With d = 3e7 elements, d - 1 == d in float32, so computing the percentile
  index in float yields an out-of-range index and the op fails with an
  InvalidArgumentError. The test therefore only passes when the index
  arithmetic is done in double.

  Note: despite "max" in the name, the assertion is on the q=0 result, which
  is expected to equal 0 (the start of the linspace). q=0 is the case that
  maps to the largest index, d - 1.
  """
  num_elements = int(3e7)
  x = math_ops.linspace(0., 3e7, num=num_elements)
  with self.test_session():
    smallest = sample_stats.percentile(x, q=0, validate_args=True)
    self.assertAllEqual(0, smallest.eval())
468+
458469

459470
if __name__ == "__main__":
460471
test.main()

tensorflow/contrib/distributions/python/ops/sample_stats.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from tensorflow.python.framework import tensor_util
2626
from tensorflow.python.ops import array_ops
2727
from tensorflow.python.ops import check_ops
28+
from tensorflow.python.ops import clip_ops
2829
from tensorflow.python.ops import control_flow_ops
2930
from tensorflow.python.ops import math_ops
3031
from tensorflow.python.ops import nn_ops
@@ -301,13 +302,16 @@ def percentile(x,
301302

302303
with ops.name_scope(name, [x, q]):
303304
x = ops.convert_to_tensor(x, name="x")
304-
q = math_ops.to_float(q, name="q")
305+
# Double is needed here and below, else we get the wrong index if the array
306+
# is huge along axis.
307+
q = math_ops.to_double(q, name="q")
305308
_get_static_ndims(q, expect_ndims=0)
306309

307310
if validate_args:
308311
q = control_flow_ops.with_dependencies([
309-
check_ops.assert_rank(q, 0), check_ops.assert_greater_equal(q, 0.),
310-
check_ops.assert_less_equal(q, 100.)
312+
check_ops.assert_rank(q, 0),
313+
check_ops.assert_greater_equal(q, math_ops.to_double(0.)),
314+
check_ops.assert_less_equal(q, math_ops.to_double(100.))
311315
], q)
312316

313317
if axis is None:
@@ -332,7 +336,7 @@ def percentile(x,
332336
y = _move_dims_to_flat_end(x, axis, x_ndims)
333337

334338
frac_at_q_or_above = 1. - q / 100.
335-
d = math_ops.to_float(array_ops.shape(y)[-1])
339+
d = math_ops.to_double(array_ops.shape(y)[-1])
336340

337341
if interpolation == "lower":
338342
index = math_ops.ceil((d - 1) * frac_at_q_or_above)
@@ -341,12 +345,18 @@ def percentile(x,
341345
elif interpolation == "nearest":
342346
index = math_ops.round((d - 1) * frac_at_q_or_above)
343347

348+
# If d is gigantic, then we would have d == d - 1, even in double... So
349+
# let's use max/min to avoid out of bounds errors.
350+
d = array_ops.shape(y)[-1]
351+
# d - 1 will be distinct from d in int32.
352+
index = clip_ops.clip_by_value(math_ops.to_int32(index), 0, d - 1)
353+
344354
# Sort everything, not just the top 'k' entries, which allows multiple calls
345355
# to sort only once (under the hood) and use CSE.
346356
sorted_y = _sort_tensor(y)
347357

348358
# result.shape = B
349-
result = sorted_y[..., math_ops.to_int32(index)]
359+
result = sorted_y[..., index]
350360
result.set_shape(y.get_shape()[:-1])
351361

352362
if keep_dims:

0 commit comments

Comments
 (0)