From b8a969d6503362c4b476c78f6ba96e78068f6240 Mon Sep 17 00:00:00 2001 From: Ralf Gommers Date: Thu, 29 Feb 2024 12:29:09 +0100 Subject: [PATCH 01/11] MAINT: convert `numpy.array_api` to array-api-strict The `numpy.array_api` module has been converted into a standalone package (`array-api-strict`). This new package is stable and has had a 1.0 release. The `numpy.array_api` module was marked experimental, and will be removed for NumPy 2.0. Since `array-api-strict` works with both NumPy 1.2x and NumPy 2.0, using that for testing compliance to the syntax and semantics of the array API standard should always be preferred over still trying to use `numpy.array_api`. --- doc/modules/array_api.rst | 2 +- sklearn/decomposition/tests/test_pca.py | 2 +- sklearn/linear_model/_base.py | 2 +- sklearn/model_selection/_split.py | 9 ++++++ sklearn/model_selection/tests/test_split.py | 1 + sklearn/utils/__init__.py | 19 ++++++++--- sklearn/utils/_array_api.py | 33 +++++++++++++------- sklearn/utils/_testing.py | 8 +---- sklearn/utils/tests/test_array_api.py | 24 ++++++-------- sklearn/utils/tests/test_estimator_checks.py | 6 ++-- sklearn/utils/tests/test_validation.py | 2 +- 11 files changed, 65 insertions(+), 43 deletions(-) diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index ebc157fb169d1..a046881d84f29 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -25,7 +25,7 @@ At this stage, this support is **considered experimental** and must be enabled explicitly as explained in the following. .. note:: - Currently, only `cupy.array_api`, `numpy.array_api`, `cupy`, and `PyTorch` + Currently, only `cupy.array_api`, `array-api-strict`, `cupy`, and `PyTorch` are known to work with scikit-learn's estimators. Example usage diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 44281b9038697..6b0d2c456c18e 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -957,7 +957,7 @@ def test_pca_mle_array_api_compliance( def test_array_api_error_and_warnings_on_unsupported_params(): pytest.importorskip("array_api_compat") - xp = pytest.importorskip("numpy.array_api") + xp = pytest.importorskip("array_api_strict") iris_xp = xp.asarray(iris.data) pca = PCA(n_components=2, svd_solver="arpack", random_state=0) diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index f07e974542a5b..417ada8f7adbc 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -350,7 +350,7 @@ def predict(self, X): xp, _ = get_namespace(X) scores = self.decision_function(X) if len(scores.shape) == 1: - indices = xp.astype(scores > 0, int) + indices = xp.astype(scores > 0, xp.int64) else: indices = xp.argmax(scores, axis=1) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index a1c5194d1dbef..207271d632da2 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -29,6 +29,7 @@ indexable, metadata_routing, ) +from ..utils._array_api import get_namespace, device as array_api_device from ..utils._param_validation import Interval, RealNotInt, validate_params from ..utils.extmath import _approximate_mode from ..utils.metadata_routing import _MetadataRequester @@ -2787,6 +2788,14 @@ def train_test_split( train, test = next(cv.split(X=arrays[0], y=stratify)) + xp, is_array_api_compliant = get_namespace(arrays[0]) + if is_array_api_compliant: + # For indexing, the data and index arrays need to be of the same type + 
# and on the same device. + _device = array_api_device(arrays[0]) + train = xp.asarray(train, device=_device) + test = xp.asarray(test, device=_device) + return list( chain.from_iterable( (_safe_indexing(a, train), _safe_indexing(a, test)) for a in arrays diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 2afb9ae6adce7..4b24c18de08d7 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -1315,6 +1315,7 @@ def test_array_api_train_test_split( stratify_xp = xp.asarray(stratify) else: stratify_xp = stratify + X_train_xp, X_test_xp, y_train_xp, y_test_xp = train_test_split( X_xp, y_xp, shuffle=shuffle, stratify=stratify_xp, random_state=0 ) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 354d8240045ed..fffaf8e21fcee 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -16,6 +16,7 @@ from ..exceptions import DataConversionWarning from . import _joblib, metadata_routing +from ._array_api import get_namespace from ._bunch import Bunch from ._chunking import gen_batches, gen_even_slices from ._estimator_html_repr import estimator_html_repr @@ -92,6 +93,12 @@ def _array_indexing(array, key, key_dtype, axis): key = np.asarray(key) if isinstance(key, tuple): key = list(key) + if key_dtype == 'an-array-dtype': + # A non-numpy array for which determining the dtype kind wasn't + # possible. Use `take` over `__getitem__` to support integer indexing. + xp, _ = get_namespace(array) + if hasattr(xp, 'isdtype') and xp.isdtype(key.dtype, 'integral'): + return xp.take(array, key, axis=axis) return array[key, ...] if axis == 0 else array[:, key] @@ -197,10 +204,14 @@ def _determine_key_type(key, accept_slice=True): raise ValueError(err_msg) return key_type.pop() if hasattr(key, "dtype"): - try: - return array_dtype_to_str[key.dtype.kind] - except KeyError: - raise ValueError(err_msg) + if hasattr(key.dtype, "kind"): + try: + return array_dtype_to_str[key.dtype.kind] + except KeyError: + raise ValueError(err_msg) + else: + # A `dtype` for a non-numpy library may not have a kind attribute + return 'an-array-dtype' raise ValueError(err_msg) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 1131cb3560287..fe4dc0337fa10 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -34,8 +34,8 @@ def yield_namespace_device_dtype_combinations(): # tests are regular numpy arrays without any "device" attribute. "numpy", # Stricter NumPy-based Array API implementation. The - # numpy.array_api.Array instances always a dummy "device" attribute. - "numpy.array_api", + # array_api_strict.Array instances always a dummy "device" attribute. 
+ "array_api_strict", "cupy", "cupy.array_api", "torch", @@ -110,7 +110,7 @@ def size(x): def _is_numpy_namespace(xp): """Return True if xp is backed by NumPy.""" - return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"} + return xp.__name__ in {"numpy", "array_api_compat.numpy"} def _union1d(a, b, xp): @@ -149,7 +149,6 @@ def _isdtype_single(dtype, kind, *, xp): return dtype in supported_float_dtypes(xp) elif kind == "complex floating": # Some name spaces do not have complex, such as cupy.array_api - # and numpy.array_api complex_dtypes = set() if hasattr(xp, "complex64"): complex_dtypes.add(xp.complex64) @@ -226,8 +225,8 @@ def wrapped_func(*args, **kwargs): class _NumPyAPIWrapper: """Array API compat wrapper for any numpy version - NumPy < 1.22 does not expose the numpy.array_api namespace. This - wrapper makes it possible to write code that uses the standard + NumPy < 1.22 and >=2.0 do not expose the numpy.array_api namespace. + This wrapper makes it possible to write code that uses the standard Array API while working with any version of NumPy supported by scikit-learn. @@ -336,7 +335,7 @@ def get_namespace(*arrays): Introspect `arrays` arguments and return their common Array API compatible namespace object, if any. NumPy 1.22 and later can - construct such containers using the `numpy.array_api` namespace + construct such containers using the `array_api_strict` namespace for instance. See: https://numpy.org/neps/nep-0047-array-api-standard.html @@ -390,7 +389,7 @@ def get_namespace(*arrays): # These namespaces need additional wrapping to smooth out small differences # between implementations - if namespace.__name__ in {"numpy.array_api", "cupy.array_api"}: + if namespace.__name__ in {"array_api_strict", "cupy.array_api"}: namespace = _ArrayAPIWrapper(namespace) return namespace, is_array_api_compliant @@ -440,9 +439,21 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): return float(numpy.average(sample_score_np, weights=sample_weight_np)) if not xp.isdtype(sample_score.dtype, "real floating"): - # We move to cpu device ahead of time since certain devices may not support - # float64, but we want the same precision for all devices and namespaces. - sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64) + # We move to cpu device since certain devices may not support float64, + # but we want the same precision for all devices and namespaces. + try: + sample_score = xp.astype(sample_score, xp.float64) + except Exception: + if xp.__name__ not in {"numpy", "array_api_compat.numpy", + "cupy", "cupy.array_api", "array_api_strict"}: + # XXX: the 'cpu' string isn't portable. Once libraries support the + # 2023.12 version of the standard, there is an introspection API + # that can be used to check if float64 is supported on the current + # device (if it is, the try-except can be avoided here). And the + # copy/device keywords to `from_dlpack` can be used for + # portable "copy to host" behavior.) 
+ sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), + xp.float64) if sample_weight is not None: sample_weight = xp.asarray( diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index c78765cb1ba5e..e5601a11ffa62 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -1049,13 +1049,7 @@ def fit_transform(self, X, y=None): def _array_api_for_tests(array_namespace, device): try: - if array_namespace == "numpy.array_api": - # FIXME: once it is not experimental anymore - with ignore_warnings(category=UserWarning): - # UserWarning: numpy.array_api submodule is still experimental. - array_mod = importlib.import_module(array_namespace) - else: - array_mod = importlib.import_module(array_namespace) + array_mod = importlib.import_module(array_namespace) except ModuleNotFoundError: raise SkipTest( f"{array_namespace} is not installed: not checking array_api input" diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 1df81cf823bd6..6051881f2169b 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -25,10 +25,6 @@ skip_if_array_api_compat_not_configured, ) -pytestmark = pytest.mark.filterwarnings( - "ignore:The numpy.array_api submodule:UserWarning" -) - @pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]]) def test_get_namespace_ndarray_default(X): @@ -66,7 +62,7 @@ def test_get_namespace_ndarray_with_dispatch(): @skip_if_array_api_compat_not_configured def test_get_namespace_array_api(): """Test get_namespace for ArrayAPI arrays.""" - xp = pytest.importorskip("numpy.array_api") + xp = pytest.importorskip("array_api_strict") X_np = numpy.asarray([[1, 2, 3]]) X_xp = xp.asarray(X_np) @@ -89,8 +85,8 @@ def __init__(self, array_namespace, name): def test_array_api_wrapper_astype(): """Test _ArrayAPIWrapper for ArrayAPIs that is not NumPy.""" - numpy_array_api = pytest.importorskip("numpy.array_api") - xp_ = _AdjustableNameAPITestWrapper(numpy_array_api, "wrapped_numpy.array_api") + array_api_strict = pytest.importorskip("array_api_strict") + xp_ = _AdjustableNameAPITestWrapper(array_api_strict, "array_api_strict") xp = _ArrayAPIWrapper(xp_) X = xp.asarray(([[1, 2, 3], [3, 4, 5]]), dtype=xp.float64) @@ -101,7 +97,7 @@ def test_array_api_wrapper_astype(): assert X_converted.dtype == xp.float32 -@pytest.mark.parametrize("array_api", ["numpy", "numpy.array_api"]) +@pytest.mark.parametrize("array_api", ["numpy", "array_api_strict"]) def test_asarray_with_order(array_api): """Test _asarray_with_order passes along order for NumPy arrays.""" xp = pytest.importorskip(array_api) @@ -115,8 +111,8 @@ def test_asarray_with_order(array_api): def test_asarray_with_order_ignored(): """Test _asarray_with_order ignores order for Generic ArrayAPI.""" - xp = pytest.importorskip("numpy.array_api") - xp_ = _AdjustableNameAPITestWrapper(xp, "wrapped.array_api") + xp = pytest.importorskip("array_api_strict") + xp_ = _AdjustableNameAPITestWrapper(xp, "array_api_strict") X = numpy.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]], order="C") X = xp_.asarray(X) @@ -161,7 +157,7 @@ def test_weighted_sum( @skip_if_array_api_compat_not_configured @pytest.mark.parametrize( - "library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"] + "library", ["numpy", "array_api_strict", "cupy", "cupy.array_api", "torch"] ) @pytest.mark.parametrize( "X,reduction,expected", @@ -247,7 +243,7 @@ def fit(self, X, y=None): "array_namespace, converter", [ ("torch", lambda array: 
array.cpu().numpy()), - ("numpy.array_api", lambda array: numpy.asarray(array)), + ("array_api_strict", lambda array: numpy.asarray(array)), ("cupy.array_api", lambda array: array._array.get()), ], ) @@ -265,7 +261,7 @@ def test_convert_estimator_to_ndarray(array_namespace, converter): @skip_if_array_api_compat_not_configured def test_convert_estimator_to_array_api(): """Convert estimator attributes to ArrayAPI arrays.""" - xp = pytest.importorskip("numpy.array_api") + xp = pytest.importorskip("array_api_strict") X_np = numpy.asarray([[1.3, 4.5]]) est = SimpleEstimator().fit(X_np) @@ -294,7 +290,7 @@ def test_get_namespace_array_api_isdtype(wrapper): """Test isdtype implementation from _ArrayAPIWrapper and _NumPyAPIWrapper.""" if wrapper == _ArrayAPIWrapper: - xp_ = pytest.importorskip("numpy.array_api") + xp_ = pytest.importorskip("array_api_strict") xp = _ArrayAPIWrapper(xp_) else: xp = _NumPyAPIWrapper() diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 1e0a083a9c989..8ac7ac9db2e9a 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -529,15 +529,15 @@ def test_check_array_api_input(): except ModuleNotFoundError: raise SkipTest("array_api_compat is required to run this test") try: - importlib.import_module("numpy.array_api") + importlib.import_module("array_api_strict") except ModuleNotFoundError: # pragma: nocover - raise SkipTest("numpy.array_api is required to run this test") + raise SkipTest("array-api-strict is required to run this test") with raises(AssertionError, match="Not equal to tolerance"): check_array_api_input( "BrokenArrayAPI", BrokenArrayAPI(), - array_namespace="numpy.array_api", + array_namespace="array_api_strict", check_values=True, ) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 292c2cbbde95d..523fe041599dd 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1973,7 +1973,7 @@ def test_pandas_array_returns_ndarray(input_values): @skip_if_array_api_compat_not_configured -@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"]) +@pytest.mark.parametrize("array_namespace", ["array_api_strict", "cupy.array_api"]) def test_check_array_array_api_has_non_finite(array_namespace): """Checks that Array API arrays checks non-finite correctly.""" xp = pytest.importorskip(array_namespace) From 9e482ea543769250b320e31ef89c74860b63739d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 1 Mar 2024 11:00:24 +0100 Subject: [PATCH 02/11] apply linters --- sklearn/model_selection/_split.py | 3 ++- sklearn/utils/__init__.py | 6 +++--- sklearn/utils/_array_api.py | 14 ++++++++++---- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 207271d632da2..a77a7f02d03c6 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -29,7 +29,8 @@ indexable, metadata_routing, ) -from ..utils._array_api import get_namespace, device as array_api_device +from ..utils._array_api import device as array_api_device +from ..utils._array_api import get_namespace from ..utils._param_validation import Interval, RealNotInt, validate_params from ..utils.extmath import _approximate_mode from ..utils.metadata_routing import _MetadataRequester diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index fffaf8e21fcee..a563ea470d7d7 100644 --- 
a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -93,11 +93,11 @@ def _array_indexing(array, key, key_dtype, axis): key = np.asarray(key) if isinstance(key, tuple): key = list(key) - if key_dtype == 'an-array-dtype': + if key_dtype == "an-array-dtype": # A non-numpy array for which determining the dtype kind wasn't # possible. Use `take` over `__getitem__` to support integer indexing. xp, _ = get_namespace(array) - if hasattr(xp, 'isdtype') and xp.isdtype(key.dtype, 'integral'): + if hasattr(xp, "isdtype") and xp.isdtype(key.dtype, "integral"): return xp.take(array, key, axis=axis) return array[key, ...] if axis == 0 else array[:, key] @@ -211,7 +211,7 @@ def _determine_key_type(key, accept_slice=True): raise ValueError(err_msg) else: # A `dtype` for a non-numpy library may not have a kind attribute - return 'an-array-dtype' + return "an-array-dtype" raise ValueError(err_msg) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index fe4dc0337fa10..c174a7d9d26b6 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -444,16 +444,22 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): try: sample_score = xp.astype(sample_score, xp.float64) except Exception: - if xp.__name__ not in {"numpy", "array_api_compat.numpy", - "cupy", "cupy.array_api", "array_api_strict"}: + if xp.__name__ not in { + "numpy", + "array_api_compat.numpy", + "cupy", + "cupy.array_api", + "array_api_strict", + }: # XXX: the 'cpu' string isn't portable. Once libraries support the # 2023.12 version of the standard, there is an introspection API # that can be used to check if float64 is supported on the current # device (if it is, the try-except can be avoided here). And the # copy/device keywords to `from_dlpack` can be used for # portable "copy to host" behavior.) - sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), - xp.float64) + sample_score = xp.astype( + xp.asarray(sample_score, device="cpu"), xp.float64 + ) if sample_weight is not None: sample_weight = xp.asarray( From ef13536589125fe80de3b72c1f91cc8172144821 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 12 Mar 2024 14:01:20 +0100 Subject: [PATCH 03/11] Apply suggestions from code review Co-authored-by: Tim Head --- sklearn/model_selection/_split.py | 6 +++--- sklearn/utils/_array_api.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index a77a7f02d03c6..6176aeb71b44b 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -2793,9 +2793,9 @@ def train_test_split( if is_array_api_compliant: # For indexing, the data and index arrays need to be of the same type # and on the same device. - _device = array_api_device(arrays[0]) - train = xp.asarray(train, device=_device) - test = xp.asarray(test, device=_device) + device_ = array_api_device(arrays[0]) + train = xp.asarray(train, device=device_) + test = xp.asarray(test, device=device_) return list( chain.from_iterable( diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 45085dbd7e4ff..5eee0a22fc1ab 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -42,7 +42,7 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True): # tests are regular numpy arrays without any "device" attribute. "numpy", # Stricter NumPy-based Array API implementation. The - # array_api_strict.Array instances always a dummy "device" attribute. 
+ # array_api_strict.Array instances always have a dummy "device" attribute. "array_api_strict", "cupy", "cupy.array_api", From 72a0d12056ce3428c04477720a028737b7d580b8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 13 Mar 2024 11:36:04 +0100 Subject: [PATCH 04/11] MAINT add the array-api-strict soft dependency to one of the CI configs --- .../pylatest_conda_forge_mkl_linux-64_conda.lock | 15 ++++++++------- ...atest_conda_forge_mkl_linux-64_environment.yml | 1 + build_tools/update_environments_and_lock_files.py | 1 + 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock index beb91e614c3ee..fc29fbdb2403d 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 58f8e7b7af3826532872962cc6b88e6c57c2ed2b033142dd0dfe788d2ad041b8 +# input_hash: 08f8df7e644bbce9fdef170aced0b7954ead381d143b9827f973def270f040b3 @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2024.2.2-hbcca054_0.conda#2f4327a1cbe7f022401b236e915a5fef @@ -85,7 +85,7 @@ https://conda.anaconda.org/conda-forge/linux-64/libgpg-error-1.48-h71f35ed_0.con https://conda.anaconda.org/conda-forge/linux-64/libnghttp2-1.58.0-h47da74e_1.conda#700ac6ea6d53d5510591c4344d5c989a https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.43-h2797004_0.conda#009981dd9cfcaa4dbfa25ffaed86bcae https://conda.anaconda.org/conda-forge/linux-64/libprotobuf-3.21.12-hfc55251_2.conda#e3a7d4ba09b8dc939b98fef55f539220 -https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.45.1-h2797004_0.conda#fc4ccadfbf6d4784de88c41704792562 +https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.45.2-h2797004_0.conda#866983a220e27a80cb75e85cb30466a1 https://conda.anaconda.org/conda-forge/linux-64/libssh2-1.11.0-h0841786_0.conda#1f5a58e686b13bcfde88b93f547d23fe https://conda.anaconda.org/conda-forge/linux-64/libvorbis-1.3.7-h9c3ff4c_0.tar.bz2#309dec04b70a3cc0f1e84a4013683bc0 https://conda.anaconda.org/conda-forge/linux-64/libxcb-1.15-h0b41bf4_0.conda#33277193f5b92bad9fdd230eb700929c @@ -104,7 +104,7 @@ https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.0.9-h166bdaf_9.cond https://conda.anaconda.org/conda-forge/linux-64/freetype-2.12.1-h267a509_2.conda#9ae35c3d96db2c94ce0cef86efdfa2cb https://conda.anaconda.org/conda-forge/linux-64/krb5-1.21.2-h659d440_0.conda#cd95826dbd331ed1be26bdf401432844 https://conda.anaconda.org/conda-forge/linux-64/libgcrypt-1.10.3-hd590300_0.conda#32d16ad533c59bb0a3c5ffaf16110829 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.78.4-hf2295e7_4.conda#33c2d37e9460579b5b9a7194e5637884 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.80.0-hf2295e7_0.conda#6c0d5a4f5292e54bf9b8dc14ee7df448 https://conda.anaconda.org/conda-forge/linux-64/libgrpc-1.54.3-hb20ce57_0.conda#7af7c59ab24db007dfd82e0a3a343f66 https://conda.anaconda.org/conda-forge/linux-64/libhiredis-1.0.2-h2cc385e_0.tar.bz2#b34907d3a81a3cd8095ee83d174c074a https://conda.anaconda.org/conda-forge/linux-64/libhwloc-2.9.3-default_h554bfaf_1009.conda#f36ddc11ca46958197a45effdd286e45 @@ -135,7 +135,7 @@ https://conda.anaconda.org/conda-forge/linux-64/dbus-1.13.6-h5008d03_3.tar.bz2#e 
https://conda.anaconda.org/conda-forge/noarch/exceptiongroup-1.2.0-pyhd8ed1ab_2.conda#8d652ea2ee8eaee02ed8dc820bc794aa https://conda.anaconda.org/conda-forge/noarch/execnet-2.0.2-pyhd8ed1ab_0.conda#67de0d8241e1060a479e3c37793e26f9 https://conda.anaconda.org/conda-forge/linux-64/fontconfig-2.14.2-h14ed4e7_0.conda#0f69b688f52ff6da70bccb7ff7001d1d -https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.78.4-hfc55251_4.conda#a9d7db7ef34f21ca805bdcab32880081 +https://conda.anaconda.org/conda-forge/linux-64/glib-tools-2.80.0-hde27a5a_0.conda#072608f7b71755993a294c3cdf909fa6 https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_0.conda#f800d2da156d08e289b14e87e43c1ae5 https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.5-py311h9547e67_1.conda#2c65bdf442b0d37aad080c8a4e0d452f https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.16-hb7c19ff_0.conda#51bb7010fc86f70eee639b4bb7a894f5 @@ -170,12 +170,12 @@ https://conda.anaconda.org/conda-forge/linux-64/aws-c-mqtt-0.9.3-hb447be9_1.cond https://conda.anaconda.org/conda-forge/linux-64/cairo-1.18.0-h3faef2a_0.conda#f907bb958910dc404647326ca80c263e https://conda.anaconda.org/conda-forge/linux-64/coverage-7.4.3-py311h459d7ec_1.conda#4fb7f674bf6839da62317a7c6e725c55 https://conda.anaconda.org/conda-forge/linux-64/fonttools-4.49.0-py311h459d7ec_0.conda#d66c9e36ab104f94e35b015c86c2fcb4 -https://conda.anaconda.org/conda-forge/linux-64/glib-2.78.4-hfc55251_4.conda#1ea9d57c3e9a56338a503513b1264c3d +https://conda.anaconda.org/conda-forge/linux-64/glib-2.80.0-hf2295e7_0.conda#56a4c2ed7723cf0847752f74dba1929f https://conda.anaconda.org/conda-forge/noarch/joblib-1.3.2-pyhd8ed1ab_0.conda#4da50d410f553db77e62ab62ffaa1abc https://conda.anaconda.org/conda-forge/linux-64/libclang-15.0.7-default_hb11cfb5_4.conda#c90f4cbb57839c98fef8f830e4b9972f https://conda.anaconda.org/conda-forge/linux-64/libgoogle-cloud-2.12.0-hac9eb74_1.conda#0dee716254497604762957076ac76540 https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.6.0-hd429924_1.conda#1dbcc04604fdf1e526e6d1b0b6938396 -https://conda.anaconda.org/conda-forge/noarch/meson-1.3.2-pyhd8ed1ab_0.conda#8d18c47cc233a35c81450ba1ce601eb4 +https://conda.anaconda.org/conda-forge/noarch/meson-1.4.0-pyhd8ed1ab_0.conda#52a0660cfa40b45bf254ecc3374cb2e0 https://conda.anaconda.org/conda-forge/linux-64/mkl-2022.2.1-h84fe81f_16997.conda#a7ce56d5757f5b57e7daabe703ade5bb https://conda.anaconda.org/conda-forge/linux-64/pillow-10.2.0-py311ha6c5da5_0.conda#a5ccd7f2271f28b7d2de0b02b64e3796 https://conda.anaconda.org/conda-forge/noarch/pip-24.0-pyhd8ed1ab_0.conda#f586ac1e56c8638b64f9c8122a7b8a67 @@ -200,10 +200,11 @@ https://conda.anaconda.org/conda-forge/linux-64/liblapack-3.9.0-16_linux64_mkl.t https://conda.anaconda.org/conda-forge/linux-64/aws-sdk-cpp-1.10.57-h85b1a90_19.conda#0605d3d60857fc07bd6a11e878fe0f08 https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py311h64a7726_0.conda#a502d7aad449a1206efb366d6a12c52d https://conda.anaconda.org/conda-forge/linux-64/qt-main-5.15.8-h5810be5_19.conda#54866f708d43002a514d0b9b0f84bc11 +https://conda.anaconda.org/conda-forge/noarch/array-api-strict-1.1-pyhd8ed1ab_0.conda#db4260fac4412db30bd5213b6c9f6ecc https://conda.anaconda.org/conda-forge/linux-64/contourpy-1.2.0-py311h9547e67_0.conda#40828c5b36ef52433e21f89943e09f33 https://conda.anaconda.org/conda-forge/linux-64/libarrow-12.0.1-hb87d912_8_cpu.conda#3f3b11398fe79b578e3c44dd00a44e4a https://conda.anaconda.org/conda-forge/linux-64/pandas-2.2.1-py311h320fe9a_0.conda#aac8d7137fedc2fd5f8320bf50e4204c 
-https://conda.anaconda.org/conda-forge/linux-64/polars-0.20.14-py311ha78dd58_0.conda#89f481d3216789459de840a5f104dc04 +https://conda.anaconda.org/conda-forge/linux-64/polars-0.20.15-py311h78b473b_0.conda#92bdda09eefc2c22d1aed663b682ba9c https://conda.anaconda.org/conda-forge/linux-64/pyqt-5.15.9-py311hf0fb5b6_5.conda#ec7e45bc76d9d0b69a74a2075932b8e8 https://conda.anaconda.org/conda-forge/linux-64/pytorch-1.13.1-cpu_py311h410fd25_1.conda#ddd2fadddf89e3dc3d541a2537fce010 https://conda.anaconda.org/conda-forge/linux-64/scipy-1.12.0-py311h64a7726_2.conda#24ca5107ab75c5521067b8ba505dfae5 diff --git a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml index 8ad97b91f8fce..f4dc622417040 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml +++ b/build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml @@ -28,3 +28,4 @@ dependencies: - polars - pyarrow - array-api-compat + - array-api-strict diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index add176b93ba82..dbb23f53a362b 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -104,6 +104,7 @@ def remove_from(alist, to_remove): "polars", "pyarrow", "array-api-compat", + "array-api-strict", ], "package_constraints": { "blas": "[build=mkl]", From f268cf73749c73df73df33cefc3bed797fb8defb Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 14 Mar 2024 11:24:25 +0100 Subject: [PATCH 05/11] Merge main + update _array_indexing --- sklearn/utils/_indexing.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/_indexing.py b/sklearn/utils/_indexing.py index 3938d9b4a4c2a..49e04a451dcb0 100644 --- a/sklearn/utils/_indexing.py +++ b/sklearn/utils/_indexing.py @@ -6,6 +6,7 @@ import numpy as np from scipy.sparse import issparse +from ._array_api import _is_numpy_namespace, get_namespace from ._param_validation import Interval, validate_params from .extmath import _approximate_mode from .validation import ( @@ -21,6 +22,9 @@ def _array_indexing(array, key, key_dtype, axis): """Index an array or scipy.sparse consistently across NumPy version.""" + xp, is_array_api = get_namespace(array) + if is_array_api: + return xp.take(array, key, axis=axis) if issparse(array) and key_dtype == "bool": key = np.asarray(key) if isinstance(key, tuple): @@ -147,10 +151,22 @@ def _determine_key_type(key, accept_slice=True): raise ValueError(err_msg) return key_type.pop() if hasattr(key, "dtype"): - try: - return array_dtype_to_str[key.dtype.kind] - except KeyError: - raise ValueError(err_msg) + xp, is_array_api = get_namespace(key) + # NumPy arrays are special-cased in their own branch because the Array API + # cannot handle object/string-based dtypes that are often used to index + # columns of dataframes by names. 
+ if is_array_api and not _is_numpy_namespace(xp): + if xp.isdtype(key.dtype, "bool"): + return "bool" + elif xp.isdtype(key.dtype, "integral"): + return "int" + else: + raise ValueError(err_msg) + else: + try: + return array_dtype_to_str[key.dtype.kind] + except KeyError: + raise ValueError(err_msg) raise ValueError(err_msg) From 46a3936b62ffdb0bc0d5bdc3f423e7bb02cc1e26 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 14 Mar 2024 11:48:00 +0100 Subject: [PATCH 06/11] Introduce and use indexing_dtype(xp) --- sklearn/linear_model/_base.py | 4 ++-- sklearn/utils/_array_api.py | 24 ++++++++++++++++++++++++ sklearn/utils/tests/test_array_api.py | 14 ++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index 417ada8f7adbc..be8c9097332eb 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -33,7 +33,7 @@ _fit_context, ) from ..utils import check_array, check_random_state -from ..utils._array_api import get_namespace +from ..utils._array_api import get_namespace, indexing_dtype from ..utils._seq_dataset import ( ArrayDataset32, ArrayDataset64, @@ -350,7 +350,7 @@ def predict(self, X): xp, _ = get_namespace(X) scores = self.decision_function(X) if len(scores.shape) == 1: - indices = xp.astype(scores > 0, xp.int64) + indices = xp.astype(scores > 0, indexing_dtype(xp)) else: indices = xp.argmax(scores, axis=1) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 5eee0a22fc1ab..fa96a4dc4f43f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -733,3 +733,27 @@ def _estimator_with_converted_arrays(estimator, converter): def _atol_for_type(dtype): """Return the absolute tolerance for a given numpy dtype.""" return numpy.finfo(dtype).eps * 100 + + +def indexing_dtype(xp): + """Return a platform-specific integer dtype suitable for indexing. + + On 32-bit platforms, this will typically return int32, and int64 otherwise. + + Note: using dtype is recommended for indexing transient array + datastructures. For long-lived arrays, such as the fitted attributes of + estimators, it is instead recommended to use platform-independent int32 if + we do not expect to index more than 2B elements. Using fixed dtypes simplifies + the handling of serialized models, e.g. to deploy a model fit on a 64-bit + platform to a target 32-bit platform such as WASM/pyodide. + """ + # Currently this is implemented with a simple hack that assumes that + # the following "may be" statements in the Array API spec always hold: + # > The default integer data type should be the same across platforms, but + # > the default may vary depending on whether Python is 32-bit or 64-bit. + # > The default array index data type may be int32 on 32-bit platforms, but + # > the default should be int64 otherwise.
# https://data-apis.org/array-api/latest/API_specification/data_types.html#default-data-types + # TODO: once sufficiently adopted, we might want to instead rely on the + # newer inspection API: https://github.com/data-apis/array-api/issues/640 + return xp.asarray(0).dtype diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index c79578aa035c7..5284899564f65 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -7,6 +7,7 @@ from sklearn._config import config_context from sklearn.base import BaseEstimator +from sklearn.utils import _IS_32BIT from sklearn.utils._array_api import ( _ArrayAPIWrapper, _asarray_with_order, @@ -19,6 +20,7 @@ _NumPyAPIWrapper, device, get_namespace, + indexing_dtype, supported_float_dtypes, yield_namespace_device_dtype_combinations, ) @@ -469,3 +471,15 @@ def test_get_namespace_array_api_isdtype(wrapper): with pytest.raises(ValueError, match="Unrecognized data type"): assert xp.isdtype(xp.int16, "unknown") + + +@pytest.mark.parametrize( + "namespace, _device, _dtype", yield_namespace_device_dtype_combinations() ) +def test_indexing_dtype(namespace, _device, _dtype): + xp = _array_api_for_tests(namespace, _device) + + if _IS_32BIT: + assert indexing_dtype(xp) == xp.int32 + else: + assert indexing_dtype(xp) == xp.int64 From 1aa49e74faae31f44bb72d59a6d17ab76e8225fc Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 14 Mar 2024 11:59:32 +0100 Subject: [PATCH 07/11] DOC better reflect Array API compliance objectives in _NumPyAPIWrapper --- sklearn/utils/_array_api.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index fa96a4dc4f43f..7f88c48e1e88e 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -270,14 +270,20 @@ def wrapped_func(*args, **kwargs): class _NumPyAPIWrapper: """Array API compat wrapper for any numpy version - NumPy < 1.22 and >=2.0 do not expose the numpy.array_api namespace. This - wrapper makes it possible to write code that uses the standard - Array API while working with any version of NumPy supported by - scikit-learn. + NumPy < 2 does not implement the namespace. NumPy 2 and later should + progressively implement more and more of the latest Array API spec but this + is still work in progress at this time. + + This wrapper makes it possible to write code that uses the standard Array + API while working with any version of NumPy supported by scikit-learn. See the `get_namespace()` public function for more details. """ + # TODO: once scikit-learn drops support for NumPy < 2, this class can be + # removed, assuming Array API compliance of NumPy 2 is actually sufficient + # for scikit-learn's needs.
# Creation functions in spec: # https://data-apis.org/array-api/latest/API_specification/creation_functions.html _CREATION_FUNCS = { From 9308c3b37603ace8c5f369123b4b846f59136961 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 14 Mar 2024 12:05:29 +0100 Subject: [PATCH 08/11] More docstring updates and do not wrap array_api_strict with _ArrayAPIWrapper --- sklearn/utils/_array_api.py | 23 ++++++++++------------- sklearn/utils/tests/test_array_api.py | 1 - 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 7f88c48e1e88e..ed4d6bb05ac8b 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -423,18 +423,16 @@ def _remove_non_arrays(*arrays, remove_none=True, remove_types=(str,)): def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): """Get namespace of arrays. - Introspect `arrays` arguments and return their common Array API - compatible namespace object, if any. NumPy 1.22 and later can - construct such containers using the `array_api_strict` namespace - for instance. + Introspect `arrays` arguments and return their common Array API compatible + namespace object, if any. It is possible to construct such containers using + the `array_api_strict` namespace for instance. See: https://numpy.org/neps/nep-0047-array-api-standard.html - If `arrays` are regular numpy arrays, an instance of the - `_NumPyAPIWrapper` compatibility wrapper is returned instead. + If `arrays` are regular numpy arrays, an instance of the `_NumPyAPIWrapper` + compatibility wrapper is returned instead. - Namespace support is not enabled by default. To enabled it - call: + Namespace support is not enabled by default. To enable it call: sklearn.set_config(array_api_dispatch=True) @@ -443,10 +441,9 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): with sklearn.config_context(array_api_dispatch=True): # your code here - Otherwise an instance of the `_NumPyAPIWrapper` - compatibility wrapper is always returned irrespective of - the fact that arrays implement the `__array_namespace__` - protocol or not. + Otherwise an instance of the `_NumPyAPIWrapper` compatibility wrapper is + always returned irrespective of the fact that arrays implement the + `__array_namespace__` protocol or not. Parameters ---------- @@ -500,7 +497,7 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): # These namespaces need additional wrapping to smooth out small differences # between implementations - if namespace.__name__ in {"array_api_strict", "cupy.array_api"}: + if namespace.__name__ in {"cupy.array_api"}: namespace = _ArrayAPIWrapper(namespace) return namespace, is_array_api_compliant diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 5284899564f65..98ba444bf98b0 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -73,7 +73,6 @@ def test_get_namespace_array_api(): with config_context(array_api_dispatch=True): xp_out, is_array_api_compliant = get_namespace(X_xp) assert is_array_api_compliant - assert isinstance(xp_out, _ArrayAPIWrapper) with pytest.raises(TypeError): xp_out, is_array_api_compliant = get_namespace(X_xp, X_np) From 9308c3b37603ace8c5f369123b4b846f59136961 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 18 Mar 2024 16:08:01 +0100 Subject: [PATCH 09/11] Remove blank line introduced when resolving conflict.
--- sklearn/model_selection/tests/test_split.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index 4b24c18de08d7..2afb9ae6adce7 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -1315,7 +1315,6 @@ def test_array_api_train_test_split( stratify_xp = xp.asarray(stratify) else: stratify_xp = stratify - X_train_xp, X_test_xp, y_train_xp, y_test_xp = train_test_split( X_xp, y_xp, shuffle=shuffle, stratify=stratify_xp, random_state=0 ) From d82ae16dc99dda1c3fe10f2e6a7f922bb81480c8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 18 Mar 2024 16:35:07 +0100 Subject: [PATCH 10/11] Revert update made to pylatest_conda_forge_mkl_osx-64_conda.lock --- build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock index 8221e0e9c1b36..fb1268d7fb700 100644 --- a/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock +++ b/build_tools/azure/pylatest_conda_forge_mkl_osx-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: osx-64 -# input_hash: 1a426ea210e386d35f7d10d1994232053aaddcffe015b7c418298385f796c6e5 +# input_hash: 8d19b3cb048dd1e254e00f21d81841feddd52c98a15661153cb472e9903b5cb3 @EXPLICIT https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-h10d778d_5.conda#6097a6ca9ada32699b5fc4312dd6ef18 https://conda.anaconda.org/conda-forge/osx-64/ca-certificates-2024.2.2-h8857fd0_0.conda#f2eacee8c33c43692f1ccfd33d0f50b1 From b4befe91d6be0d918900b776bd15737b50aad79b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 18 Mar 2024 17:21:16 +0100 Subject: [PATCH 11/11] Simplify docstring. Co-authored-by: Tim Head --- sklearn/utils/_array_api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index f3349aabc285e..70e210e7e913e 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -453,8 +453,7 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): """Get namespace of arrays. Introspect `arrays` arguments and return their common Array API compatible - namespace object, if any. It is possible to construct such containers using - the `array_api_strict` namespace for instance. + namespace object, if any. See: https://numpy.org/neps/nep-0047-array-api-standard.html