diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 065be410a6273..2eb61d3875859 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -56,6 +56,12 @@ See :ref:`array_api` for more details. compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim Head <betatim>` and :user:`Olivier Grisel <ogrisel>`. +Other changes impacting array API support in general: + +- |Fix| :func:`validation.check_array` now accepts scipy sparse inputs without error + even when array API dispatch is enabled. + :pr:`29466` by :user:`Olivier Grisel <ogrisel>`. + Metadata Routing ---------------- diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index a00d250ab31d2..5d172a021233f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -5,6 +5,7 @@ from functools import wraps import numpy +import scipy.sparse as sp import scipy.special as special from .._config import get_config @@ -550,6 +551,19 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): # message in case it is missing. import array_api_compat + # Special handling of scipy sparse inputs, potentially mixed with dense + # arrays. + if all(sp.issparse(a) for a in arrays): + # For all-scipy sparse inputs, it's safe to assume that the namespace + # is numpy (or its array-api-compat wrapper for versions before 2.0). + return array_api_compat.get_namespace(numpy.empty(shape=0)), True + else: + # For mixed dense/sparse array inputs, ignore the sparse arrays and + # proceed only with the dense ones. The caller code in scikit-learn + # should always be in charge of special handling scipy sparse + # operations under sp.issparse() branches when needed. 
+ arrays = [a for a in arrays if not sp.issparse(a)] + namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True # These namespaces need additional wrapping to smooth out small differences diff --git a/sklearn/utils/tests/test_array_api.py b/sklearn/utils/tests/test_array_api.py index 71f499f7a8dae..e3695b17c5268 100644 --- a/sklearn/utils/tests/test_array_api.py +++ b/sklearn/utils/tests/test_array_api.py @@ -3,6 +3,7 @@ import numpy import pytest +import scipy.sparse as sp from numpy.testing import assert_allclose from sklearn._config import config_context @@ -60,19 +61,31 @@ def test_get_namespace_ndarray_creation_device(): @skip_if_array_api_compat_not_configured def test_get_namespace_ndarray_with_dispatch(): """Test get_namespace on NumPy ndarrays.""" - array_api_compat = pytest.importorskip("array_api_compat") + array_api_compat_numpy = pytest.importorskip("array_api_compat.numpy") X_np = numpy.asarray([[1, 2, 3]]) + if np_version >= parse_version("2.0.0"): + # NumPy 2.0+ is an array API compliant library. + expected_namespace = numpy + else: + # Older NumPy versions require the compatibility layer. + expected_namespace = array_api_compat_numpy with config_context(array_api_dispatch=True): xp_out, is_array_api_compliant = get_namespace(X_np) assert is_array_api_compliant - if np_version >= parse_version("2.0.0"): - # NumPy 2.0+ is an array API compliant library. - assert xp_out is numpy - else: - # Older NumPy versions require the compatibility layer. 
- assert xp_out is array_api_compat.numpy + assert xp_out is expected_namespace + + X_sp = sp.csr_array(X_np) + with config_context(array_api_dispatch=True): + xp_out, is_array_api_compliant = get_namespace(X_sp) + assert is_array_api_compliant + assert xp_out is expected_namespace + + with config_context(array_api_dispatch=True): + xp_out, is_array_api_compliant = get_namespace(X_np, X_sp) + assert is_array_api_compliant + assert xp_out is expected_namespace @skip_if_array_api_compat_not_configured @@ -82,13 +95,20 @@ def test_get_namespace_array_api(): X_np = numpy.asarray([[1, 2, 3]]) X_xp = xp.asarray(X_np) + X_sp = sp.csr_array(X_np) + with config_context(array_api_dispatch=True): xp_out, is_array_api_compliant = get_namespace(X_xp) assert is_array_api_compliant + assert xp_out is xp with pytest.raises(TypeError): xp_out, is_array_api_compliant = get_namespace(X_xp, X_np) + xp_out, is_array_api_compliant = get_namespace(X_xp, X_sp) + assert is_array_api_compliant + assert xp_out is xp + class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper): """API wrapper that has an adjustable name. Used for testing.""" diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 5bde51ae514d9..afc3d743a95d6 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2018,6 +2018,16 @@ def test_check_array_array_api_has_non_finite(array_namespace): check_array(X_inf) +@skip_if_array_api_compat_not_configured +def test_check_array_on_sparse_inputs_with_array_api_enabled(): + X_sp = sp.csr_array([[0, 1, 0], [1, 0, 1]]) + with config_context(array_api_dispatch=True): + assert sp.issparse(check_array(X_sp, accept_sparse=True)) + + with pytest.raises(TypeError): + check_array(X_sp) + + @pytest.mark.parametrize( "extension_dtype, regular_dtype", [