diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index d7d3a71eba636..fb3563736577e 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -56,6 +56,10 @@ See :ref:`array_api` for more details. compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim Head ` and :user:`Olivier Grisel `. +- |Fix| func:`validation.check_array` now accepts scipy sparse inputs without error + even when array API dispatch is enabled. + :pr:`29469` by :user:`Olivier Grisel `. + Metadata Routing ---------------- diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 5bde51ae514d9..afc3d743a95d6 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2018,6 +2018,16 @@ def test_check_array_array_api_has_non_finite(array_namespace): check_array(X_inf) +@skip_if_array_api_compat_not_configured +def test_check_array_on_sparse_inputs_with_array_api_enabled(): + X_sp = sp.csr_array([[0, 1, 0], [1, 0, 1]]) + with config_context(array_api_dispatch=True): + assert sp.issparse(check_array(X_sp, accept_sparse=True)) + + with pytest.raises(TypeError): + check_array(X_sp) + + @pytest.mark.parametrize( "extension_dtype, regular_dtype", [ diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index af9fdb4a79cba..651951e8b90ca 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -838,7 +838,10 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html" ) - xp, is_array_api_compliant = get_namespace(array) + if sp.issparse(array): + xp, is_array_api_compliant = None, False + else: + xp, is_array_api_compliant = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -931,7 +934,7 @@ def is_sparse(dtype): ) ) - if dtype is not None and _is_numpy_namespace(xp): + if dtype is not None and xp is not None and _is_numpy_namespace(xp): # convert to dtype object to conform to Array API to be use `xp.isdtype` later dtype = np.dtype(dtype)