25
25
from tensorflow .python .framework import tensor_util
26
26
from tensorflow .python .ops import array_ops
27
27
from tensorflow .python .ops import check_ops
28
+ from tensorflow .python .ops import clip_ops
28
29
from tensorflow .python .ops import control_flow_ops
29
30
from tensorflow .python .ops import math_ops
30
31
from tensorflow .python .ops import nn_ops
@@ -301,13 +302,16 @@ def percentile(x,
301
302
302
303
with ops .name_scope (name , [x , q ]):
303
304
x = ops .convert_to_tensor (x , name = "x" )
304
- q = math_ops .to_float (q , name = "q" )
305
+ # Double is needed here and below, else we get the wrong index if the array
306
+ # is huge along axis.
307
+ q = math_ops .to_double (q , name = "q" )
305
308
_get_static_ndims (q , expect_ndims = 0 )
306
309
307
310
if validate_args :
308
311
q = control_flow_ops .with_dependencies ([
309
- check_ops .assert_rank (q , 0 ), check_ops .assert_greater_equal (q , 0. ),
310
- check_ops .assert_less_equal (q , 100. )
312
+ check_ops .assert_rank (q , 0 ),
313
+ check_ops .assert_greater_equal (q , math_ops .to_double (0. )),
314
+ check_ops .assert_less_equal (q , math_ops .to_double (100. ))
311
315
], q )
312
316
313
317
if axis is None :
@@ -332,7 +336,7 @@ def percentile(x,
332
336
y = _move_dims_to_flat_end (x , axis , x_ndims )
333
337
334
338
frac_at_q_or_above = 1. - q / 100.
335
- d = math_ops .to_float (array_ops .shape (y )[- 1 ])
339
+ d = math_ops .to_double (array_ops .shape (y )[- 1 ])
336
340
337
341
if interpolation == "lower" :
338
342
index = math_ops .ceil ((d - 1 ) * frac_at_q_or_above )
@@ -341,12 +345,18 @@ def percentile(x,
341
345
elif interpolation == "nearest" :
342
346
index = math_ops .round ((d - 1 ) * frac_at_q_or_above )
343
347
348
+ # If d is gigantic, then we would have d == d - 1, even in double... So
349
+ # let's use max/min to avoid out of bounds errors.
350
+ d = array_ops .shape (y )[- 1 ]
351
+ # d - 1 will be distinct from d in int32.
352
+ index = clip_ops .clip_by_value (math_ops .to_int32 (index ), 0 , d - 1 )
353
+
344
354
# Sort everything, not just the top 'k' entries, which allows multiple calls
345
355
# to sort only once (under the hood) and use CSE.
346
356
sorted_y = _sort_tensor (y )
347
357
348
358
# result.shape = B
349
- result = sorted_y [..., math_ops . to_int32 ( index ) ]
359
+ result = sorted_y [..., index ]
350
360
result .set_shape (y .get_shape ()[:- 1 ])
351
361
352
362
if keep_dims :
0 commit comments