From 6138a92fb45369b651f2180486e7904f03de9aa2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Jul 2024 16:12:53 +0200 Subject: [PATCH 1/3] FIX make check_array accept sparse inputs when array api dispatch is enabled --- sklearn/utils/tests/test_validation.py | 10 ++++++++++ sklearn/utils/validation.py | 7 +++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 5bde51ae514d9..afc3d743a95d6 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2018,6 +2018,16 @@ def test_check_array_array_api_has_non_finite(array_namespace): check_array(X_inf) +@skip_if_array_api_compat_not_configured +def test_check_array_on_sparse_inputs_with_array_api_enabled(): + X_sp = sp.csr_array([[0, 1, 0], [1, 0, 1]]) + with config_context(array_api_dispatch=True): + assert sp.issparse(check_array(X_sp, accept_sparse=True)) + + with pytest.raises(TypeError): + check_array(X_sp) + + @pytest.mark.parametrize( "extension_dtype, regular_dtype", [ diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index af9fdb4a79cba..651951e8b90ca 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -838,7 +838,10 @@ def check_array( "https://numpy.org/doc/stable/reference/generated/numpy.matrix.html" ) - xp, is_array_api_compliant = get_namespace(array) + if sp.issparse(array): + xp, is_array_api_compliant = None, False + else: + xp, is_array_api_compliant = get_namespace(array) # store reference to original array to check if copy is needed when # function returns @@ -931,7 +934,7 @@ def is_sparse(dtype): ) ) - if dtype is not None and _is_numpy_namespace(xp): + if dtype is not None and xp is not None and _is_numpy_namespace(xp): # convert to dtype object to conform to Array API to be use `xp.isdtype` later dtype = np.dtype(dtype) From 39760c84bd4360ca63ee45a4785d8564d1342cfd Mon Sep 17 00:00:00 2001 From: 
Olivier Grisel Date: Thu, 11 Jul 2024 16:39:24 +0200 Subject: [PATCH 2/3] Add a changelog entry --- doc/whats_new/v1.6.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index d7d3a71eba636..fb3563736577e 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -56,6 +56,10 @@ See :ref:`array_api` for more details. compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim Head <betatim>` and :user:`Olivier Grisel <ogrisel>`. +- |Fix| :func:`validation.check_array` now accepts scipy sparse inputs without error + even when array API dispatch is enabled. + :pr:`29469` by :user:`Olivier Grisel <ogrisel>`. + Metadata Routing ---------------- From cf2170569de17019d5cfdf5c63f122182d38940a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 11 Jul 2024 16:58:52 +0200 Subject: [PATCH 3/3] Investigate impact of rejecting sparse inputs in get_namespace --- sklearn/utils/_array_api.py | 8 ++++++++ sklearn/utils/validation.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index a00d250ab31d2..725d43e8afb95 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -5,6 +5,7 @@ from functools import wraps import numpy +import scipy.sparse as sp import scipy.special as special from .._config import get_config @@ -528,6 +529,13 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None): True if the arrays are containers that implement the Array API spec. Always False when array_api_dispatch=False. """ + if any(sp.issparse(a) for a in arrays): + # Consistently reject scipy sparse arrays, whether or not array_api_dispatch + # is enabled. + raise ValueError( + "Scipy sparse arrays or matrices are not supported in get_namespace."
+ ) + array_api_dispatch = get_config()["array_api_dispatch"] if not array_api_dispatch: if xp is not None: diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 651951e8b90ca..3893b5947a9fb 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1670,7 +1670,6 @@ def check_non_negative(X, whom): whom : str Who passed X to this function. """ - xp, _ = get_namespace(X) # avoid X.min() on sparse matrix since it also sorts the indices if sp.issparse(X): if X.format in ["lil", "dok"]: @@ -1680,6 +1679,7 @@ def check_non_negative(X, whom): else: X_min = X.data.min() else: + xp, _ = get_namespace(X) X_min = xp.min(X) if X_min < 0: