Skip to content

Commit beb130a

Browse files
committed
Fix all NaN slice handling
When using the axis keyword it can happen that a slice contains only NaNs. This change corrects the logic that restores the NaNs at the end. Added tests for 2D inputs.
1 parent 89b0cf8 commit beb130a

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

doc/modules/array_api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ Estimators with support for `Array API`-compatible inputs
9191
- :class:`decomposition.PCA` (with `svd_solver="full"`,
9292
`svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
9393
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
94+
- :class:`preprocessing.MinMaxScaler`
9495

9596
Coverage for more estimators is expected to grow over time. Please follow the
9697
dedicated `meta-issue on GitHub

sklearn/utils/_array_api.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -495,9 +495,13 @@ def _nanmin(X, axis=None):
495495
return xp.asarray(numpy.nanmin(X, axis=axis))
496496

497497
else:
498-
X = xp.min(xp.where(~xp.isnan(X), X, xp.asarray(+xp.inf)), axis=axis)
498+
mask = xp.isnan(X)
499+
X = xp.min(xp.where(mask, xp.asarray(+xp.inf), X), axis=axis)
499500
# Replace Infs from all NaN slices with NaN again
500-
return xp.where(~xp.isinf(X), X, xp.asarray(xp.nan))
501+
mask = xp.all(mask, axis=axis)
502+
if xp.any(mask):
503+
X = xp.where(~mask, X, xp.asarray(xp.nan))
504+
return X
501505

502506

503507
def _nanmax(X, axis=None):
@@ -508,9 +512,13 @@ def _nanmax(X, axis=None):
508512
return xp.asarray(numpy.nanmax(X, axis=axis))
509513

510514
else:
511-
X = xp.max(xp.where(~xp.isnan(X), X, xp.asarray(-xp.inf)), axis=axis)
515+
mask = xp.isnan(X)
516+
X = xp.max(xp.where(mask, xp.asarray(-xp.inf), X), axis=axis)
512517
# Replace Infs from all NaN slices with NaN again
513-
return xp.where(~_isneginf(X), X, xp.asarray(xp.nan))
518+
mask = xp.all(mask, axis=axis)
519+
if xp.any(mask):
520+
X = xp.where(~mask, X, xp.asarray(xp.nan))
521+
return X
514522

515523

516524
def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None):

sklearn/utils/tests/test_array_api.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from functools import partial
2+
13
import numpy
24
import pytest
35
from numpy.testing import assert_allclose, assert_array_equal
@@ -171,17 +173,43 @@ def test_asarray_with_order_ignored():
171173
[
172174
([1, 2, numpy.nan], _nanmin, 1),
173175
([1, -2, -numpy.nan], _nanmin, -2),
176+
([numpy.inf, numpy.inf], _nanmin, numpy.inf),
177+
(
178+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
179+
partial(_nanmin, axis=0),
180+
[1.0, 2.0, 3.0],
181+
),
182+
(
183+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
184+
partial(_nanmin, axis=1),
185+
[1.0, numpy.nan, 4.0],
186+
),
174187
([1, 2, numpy.nan], _nanmax, 2),
175188
([1, 2, numpy.nan], _nanmax, 2),
189+
([-numpy.inf, -numpy.inf], _nanmax, -numpy.inf),
190+
(
191+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
192+
partial(_nanmax, axis=0),
193+
[4.0, 5.0, 6.0],
194+
),
195+
(
196+
[[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
197+
partial(_nanmax, axis=1),
198+
[3.0, numpy.nan, 6.0],
199+
),
176200
],
177201
)
178202
def test_nan_reductions(library, X, reduction, expected):
179203
"""Check NaN reductions like _nanmin and _nanmax"""
180204
xp = pytest.importorskip(library)
181205

206+
if isinstance(expected, list):
207+
expected = xp.asarray(expected)
208+
182209
with config_context(array_api_dispatch=True):
183210
result = reduction(xp.asarray(X))
184-
assert result == expected
211+
212+
assert_allclose(result, expected)
185213

186214

187215
@skip_if_array_api_compat_not_configured

0 commit comments

Comments
 (0)