From 6138a92fb45369b651f2180486e7904f03de9aa2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Jul 2024 16:12:53 +0200 Subject: [PATCH 1/2] FIX make check_array accept sparse inputs when array api dispatch is enabled --- sklearn/utils/tests/test_validation.py | 10 ++++++++++ sklearn/utils/validation.py | 7 +++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 5bde51ae514d9..afc3d743a95d6 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2018,6 +2018,16 @@ def test_check_array_array_api_has_non_finite(array_namespace): check_array(X_inf) +@skip_if_array_api_compat_not_configured +def test_check_array_on_sparse_inputs_with_array_api_enabled(): + X_sp = sp.csr_array([[0, 1, 0], [1, 0, 1]]) + with config_context(array_api_dispatch=True): + assert sp.issparse(check_array(X_sp, accept_sparse=True)) + + with pytest.raises(TypeError): + check_array(X_sp) + + @pytest.mark.parametrize( "extension_dtype, regular_dtype", [ diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index af9fdb4a79cba..651951e8b90ca 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -838,7 +838,10 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html" ) - xp, is_array_api_compliant = get_namespace(array) + if sp.issparse(array): + xp, is_array_api_compliant = None, False + else: + xp, is_array_api_compliant = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -931,7 +934,7 @@ def is_sparse(dtype): ) ) - if dtype is not None and _is_numpy_namespace(xp): + if dtype is not None and xp is not None and _is_numpy_namespace(xp): # convert to dtype object to conform to Array API to be use `xp.isdtype` later dtype = np.dtype(dtype) From 39760c84bd4360ca63ee45a4785d8564d1342cfd Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Jul 2024 16:39:24 +0200 Subject: [PATCH 2/2] Add a changelog entry --- doc/whats_new/v1.6.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index d7d3a71eba636..fb3563736577e 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -56,6 +56,10 @@ See :ref:`array_api` for more details. compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim Head ` and :user:`Olivier Grisel `. +- |Fix| func:`validation.check_array` now accepts scipy sparse inputs without error + even when array API dispatch is enabled. + :pr:`29469` by :user:`Olivier Grisel `. + Metadata Routing ----------------