diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst
index 741ebbf240a6d..4f381fc916958 100644
--- a/doc/modules/array_api.rst
+++ b/doc/modules/array_api.rst
@@ -96,6 +96,7 @@ Estimators
 - :class:`decomposition.PCA` (with `svd_solver="full"`,
   `svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
 - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
+- :class:`preprocessing.MinMaxScaler`
 
 Tools
 -----
diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index dc955f7aa0f51..d97ccea9469eb 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -398,7 +398,7 @@ Changelog
 - |Feature| Compute a custom out-of-bag score by passing a callable to
   :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
   :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor`.
-  :pr:`25177` by :user:`Tim Head <betatim>`.
+  :pr:`25177` by `Tim Head`_.
 
 - |Feature| :class:`ensemble.GradientBoostingClassifier` now exposes
   out-of-bag scores via the `oob_scores_` or `oob_score_` attributes.
diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index f048e7f955995..0164c81abdebc 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -174,6 +174,9 @@ Changelog
   is enabled and should be passed via the `params` parameter. :pr:`26896`
   by `Adrin Jalali`_.
 
+- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
+  Array API compatible inputs. :pr:`26855` by `Tim Head`_.
+
 :mod:`sklearn.neighbors`
 ........................
 
@@ -197,8 +200,11 @@ Changelog
   when `sparse_output=True` and the output is configured to be pandas.
   :pr:`26931` by `Thomas Fan`_.
 
-- |Enhancement| :func:`sklearn.model_selection.train_test_split` now supports
-  Array API compatible inputs. :pr:`26855` by `Tim Head`_.
+- |MajorFeature| :class:`preprocessing.MinMaxScaler` now
+  supports the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
+  support is considered experimental and might evolve without being subject to
+  our usual rolling deprecation cycle policy. See
+  :ref:`array_api` for more details. :pr:`26243` by `Tim Head`_.
 
 :mod:`sklearn.tree`
 ...................
diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py
index b51f2af2c3d17..160e482379989 100644
--- a/sklearn/preprocessing/_data.py
+++ b/sklearn/preprocessing/_data.py
@@ -22,7 +22,8 @@
     TransformerMixin,
     _fit_context,
 )
-from ..utils import check_array
+from ..utils import _array_api, check_array
+from ..utils._array_api import get_namespace
 from ..utils._param_validation import Interval, Options, StrOptions, validate_params
 from ..utils.extmath import _incremental_mean_and_var, row_norms
 from ..utils.sparsefuncs import (
@@ -103,16 +104,18 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
         if scale == 0.0:
             scale = 1.0
         return scale
-    elif isinstance(scale, np.ndarray):
+    # scale is an array
+    else:
+        xp, _ = get_namespace(scale)
         if constant_mask is None:
             # Detect near constant values to avoid dividing by a very small
             # value that could lead to surprising results and numerical
             # stability issues.
-            constant_mask = scale < 10 * np.finfo(scale.dtype).eps
 
+            constant_mask = scale < 10 * xp.finfo(scale.dtype).eps
         if copy:
             # New array to avoid side-effects
-            scale = scale.copy()
+            scale = xp.asarray(scale, copy=True)
         scale[constant_mask] = 1.0
         return scale
 
@@ -468,22 +471,24 @@ def partial_fit(self, X, y=None):
                 "Consider using MaxAbsScaler instead."
             )
 
+        xp, _ = get_namespace(X)
+
         first_pass = not hasattr(self, "n_samples_seen_")
         X = self._validate_data(
             X,
             reset=first_pass,
-            dtype=FLOAT_DTYPES,
+            dtype=_array_api.supported_float_dtypes(xp),
             force_all_finite="allow-nan",
         )
 
-        data_min = np.nanmin(X, axis=0)
-        data_max = np.nanmax(X, axis=0)
+        data_min = _array_api._nanmin(X, axis=0)
+        data_max = _array_api._nanmax(X, axis=0)
 
         if first_pass:
             self.n_samples_seen_ = X.shape[0]
         else:
-            data_min = np.minimum(self.data_min_, data_min)
-            data_max = np.maximum(self.data_max_, data_max)
+            data_min = xp.minimum(self.data_min_, data_min)
+            data_max = xp.maximum(self.data_max_, data_max)
             self.n_samples_seen_ += X.shape[0]
 
         data_range = data_max - data_min
@@ -511,10 +516,12 @@ def transform(self, X):
         """
         check_is_fitted(self)
 
+        xp, _ = get_namespace(X)
+
         X = self._validate_data(
             X,
             copy=self.copy,
-            dtype=FLOAT_DTYPES,
+            dtype=_array_api.supported_float_dtypes(xp),
             force_all_finite="allow-nan",
             reset=False,
         )
@@ -522,7 +529,7 @@ def transform(self, X):
         X *= self.scale_
         X += self.min_
         if self.clip:
-            np.clip(X, self.feature_range[0], self.feature_range[1], out=X)
+            xp.clip(X, self.feature_range[0], self.feature_range[1], out=X)
         return X
 
     def inverse_transform(self, X):
@@ -540,8 +547,13 @@ def inverse_transform(self, X):
         """
         check_is_fitted(self)
 
+        xp, _ = get_namespace(X)
+
         X = check_array(
-            X, copy=self.copy, dtype=FLOAT_DTYPES, force_all_finite="allow-nan"
+            X,
+            copy=self.copy,
+            dtype=_array_api.supported_float_dtypes(xp),
+            force_all_finite="allow-nan",
         )
 
         X -= self.min_
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index 189d8875bc8f2..5af6172080e9e 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -41,6 +41,9 @@
 from sklearn.preprocessing._data import BOUNDS_THRESHOLD, _handle_zeros_in_scale
 from sklearn.svm import SVR
 from sklearn.utils import gen_batches, shuffle
+from sklearn.utils._array_api import (
+    yield_namespace_device_dtype_combinations,
+)
 from sklearn.utils._testing import (
     _convert_container,
     assert_allclose,
@@ -51,6 +54,10 @@
     assert_array_less,
     skip_if_32bit,
 )
+from sklearn.utils.estimator_checks import (
+    _get_check_estimator_ids,
+    check_array_api_input_and_values,
+)
 from sklearn.utils.sparsefuncs import mean_variance_axis
 
 iris = datasets.load_iris()
@@ -684,6 +691,26 @@ def test_standard_check_array_of_inverse_transform():
         scaler.inverse_transform(x)
 
 
+@pytest.mark.parametrize(
+    "array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
+)
+@pytest.mark.parametrize(
+    "check",
+    [check_array_api_input_and_values],
+    ids=_get_check_estimator_ids,
+)
+@pytest.mark.parametrize(
+    "estimator",
+    [MinMaxScaler()],
+    ids=_get_check_estimator_ids,
+)
+def test_minmaxscaler_array_api_compliance(
+    estimator, check, array_namespace, device, dtype
+):
+    name = estimator.__class__.__name__
+    check(name, estimator, array_namespace, device=device, dtype=dtype)
+
+
 def test_min_max_scaler_iris():
     X = iris.data
     scaler = MinMaxScaler()
diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py
index ed16ce767a0cd..9368c131e5cd1 100644
--- a/sklearn/utils/_array_api.py
+++ b/sklearn/utils/_array_api.py
@@ -145,7 +145,7 @@ def _isdtype_single(dtype, kind, *, xp):
             for k in ("signed integer", "unsigned integer")
         )
     elif kind == "real floating":
-        return dtype in {xp.float32, xp.float64}
+        return dtype in supported_float_dtypes(xp)
     elif kind == "complex floating":
         # Some name spaces do not have complex, such as cupy.array_api
         # and numpy.array_api
@@ -166,14 +166,29 @@ def _isdtype_single(dtype, kind, *, xp):
         return dtype == kind
 
 
+def supported_float_dtypes(xp):
+    """Supported floating point types for the namespace
+
+    Note: float16 is not officially part of the Array API spec at the
+    time of writing but scikit-learn estimators and functions can choose
+    to accept it when xp.float16 is defined.
+
+    https://data-apis.org/array-api/latest/API_specification/data_types.html
+    """
+    if hasattr(xp, "float16"):
+        return (xp.float64, xp.float32, xp.float16)
+    else:
+        return (xp.float64, xp.float32)
+
+
 class _ArrayAPIWrapper:
     """sklearn specific Array API compatibility wrapper
 
     This wrapper makes it possible for scikit-learn maintainers to deal
     with discrepancies between different implementations of the
-    Python array API standard and its evolution over time.
+    Python Array API standard and its evolution over time.
 
-    The Python array API standard specification:
+    The Python Array API standard specification:
     https://data-apis.org/array-api/latest/
 
     Documentation of the NumPy implementation:
@@ -268,6 +283,9 @@ class _NumPyAPIWrapper:
         "uint16",
         "uint32",
         "uint64",
+        # XXX: float16 is not part of the Array API spec but exposed by
+        # some namespaces.
+        "float16",
         "float32",
         "float64",
         "complex64",
@@ -393,6 +411,8 @@ def get_namespace(*arrays):
 
     namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True
 
+    # These namespaces need additional wrapping to smooth out small differences
+    # between implementations
     if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}:
         namespace = _ArrayAPIWrapper(namespace)
 
@@ -463,6 +483,40 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
     return float(xp.sum(sample_score))
 
 
+def _nanmin(X, axis=None):
+    # TODO: refactor once nan-aware reductions are standardized:
+    # https://github.com/data-apis/array-api/issues/621
+    xp, _ = get_namespace(X)
+    if _is_numpy_namespace(xp):
+        return xp.asarray(numpy.nanmin(X, axis=axis))
+
+    else:
+        mask = xp.isnan(X)
+        X = xp.min(xp.where(mask, xp.asarray(+xp.inf), X), axis=axis)
+        # Replace Infs from all NaN slices with NaN again
+        mask = xp.all(mask, axis=axis)
+        if xp.any(mask):
+            X = xp.where(mask, xp.asarray(xp.nan), X)
+        return X
+
+
+def _nanmax(X, axis=None):
+    # TODO: refactor once nan-aware reductions are standardized:
+    # https://github.com/data-apis/array-api/issues/621
+    xp, _ = get_namespace(X)
+    if _is_numpy_namespace(xp):
+        return xp.asarray(numpy.nanmax(X, axis=axis))
+
+    else:
+        mask = xp.isnan(X)
+        X = xp.max(xp.where(mask, xp.asarray(-xp.inf), X), axis=axis)
+        # Replace Infs from all NaN slices with NaN again
+        mask = xp.all(mask, axis=axis)
+        if xp.any(mask):
+            X = xp.where(mask, xp.asarray(xp.nan), X)
+        return X
+
+
 def _asarray_with_order(array, dtype=None, order=None, copy=None, *, xp=None):
     """Helper to support the order kwarg only for NumPy-backed arrays
 
diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py
index 28abf47103b25..682c052f0b39b 100644
--- a/sklearn/utils/tests/test_array_api.py
+++ b/sklearn/utils/tests/test_array_api.py
@@ -1,3 +1,5 @@
+from functools import partial
+
 import numpy
 import pytest
 from numpy.testing import assert_allclose, assert_array_equal
@@ -9,8 +11,11 @@
     _asarray_with_order,
     _convert_to_numpy,
     _estimator_with_converted_arrays,
+    _nanmax,
+    _nanmin,
     _NumPyAPIWrapper,
     get_namespace,
+    supported_float_dtypes,
 )
 from sklearn.utils._testing import skip_if_array_api_compat_not_configured
 
@@ -159,6 +164,54 @@ def test_asarray_with_order_ignored():
     assert not X_new_np.flags["F_CONTIGUOUS"]
 
 
+@skip_if_array_api_compat_not_configured
+@pytest.mark.parametrize(
+    "library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"]
+)
+@pytest.mark.parametrize(
+    "X,reduction,expected",
+    [
+        ([1, 2, numpy.nan], _nanmin, 1),
+        ([1, -2, -numpy.nan], _nanmin, -2),
+        ([numpy.inf, numpy.inf], _nanmin, numpy.inf),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmin, axis=0),
+            [1.0, 2.0, 3.0],
+        ),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmin, axis=1),
+            [1.0, numpy.nan, 4.0],
+        ),
+        ([1, 2, numpy.nan], _nanmax, 2),
+        ([1, 2, numpy.nan], _nanmax, 2),
+        ([-numpy.inf, -numpy.inf], _nanmax, -numpy.inf),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmax, axis=0),
+            [4.0, 5.0, 6.0],
+        ),
+        (
+            [[1, 2, 3], [numpy.nan, numpy.nan, numpy.nan], [4, 5, 6.0]],
+            partial(_nanmax, axis=1),
+            [3.0, numpy.nan, 6.0],
+        ),
+    ],
+)
+def test_nan_reductions(library, X, reduction, expected):
+    """Check NaN reductions like _nanmin and _nanmax"""
+    xp = pytest.importorskip(library)
+
+    if isinstance(expected, list):
+        expected = xp.asarray(expected)
+
+    with config_context(array_api_dispatch=True):
+        result = reduction(xp.asarray(X))
+
+    assert_allclose(result, expected)
+
+
 @skip_if_array_api_compat_not_configured
 @pytest.mark.parametrize("library", ["cupy", "torch", "cupy.array_api"])
 def test_convert_to_numpy_gpu(library):  # pragma: nocover
@@ -256,6 +309,9 @@ def test_get_namespace_array_api_isdtype(wrapper):
         assert xp.isdtype(xp.float64, "real floating")
         assert not xp.isdtype(xp.int32, "real floating")
 
+        for dtype in supported_float_dtypes(xp):
+            assert xp.isdtype(dtype, "real floating")
+
         assert xp.isdtype(xp.bool, "bool")
         assert not xp.isdtype(xp.float32, "bool")
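
Usage sketch (not part of the patch): a minimal end-to-end illustration of what the
hunks above enable. It assumes PyTorch and the array_api_compat package are
installed; the input values are made up.

    import torch

    from sklearn import config_context
    from sklearn.preprocessing import MinMaxScaler

    X = torch.asarray([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])

    with config_context(array_api_dispatch=True):
        scaler = MinMaxScaler()
        X_scaled = scaler.fit_transform(X)

    # Fitted attributes and outputs stay in the input namespace.
    print(type(X_scaled))          # <class 'torch.Tensor'>
    print(type(scaler.data_min_))  # <class 'torch.Tensor'>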
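
The _nanmin/_nanmax helpers added in sklearn/utils/_array_api.py rely on a masking
trick for namespaces without NaN-aware reductions. Below is a NumPy-only sketch of
that trick, with made-up values: NaNs are swapped for +inf before the plain min so
they can never win, and slices that were entirely NaN are set back to NaN afterwards.

    import numpy as np

    X = np.asarray([[1.0, np.nan], [np.nan, np.nan]])
    mask = np.isnan(X)
    # NaNs become +inf, so they never win the minimum.
    res = np.min(np.where(mask, np.inf, X), axis=1)
    # Rows that were entirely NaN must report NaN, not +inf.
    res = np.where(np.all(mask, axis=1), np.nan, res)
    print(res)  # [ 1. nan]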