Skip to content

FIX make check_array accept sparse inputs when array api dispatch is enabled #29469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ See :ref:`array_api` for more details.
compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim
Head <betatim>` and :user:`Olivier Grisel <ogrisel>`.

- |Fix| func:`validation.check_array` now accepts scipy sparse inputs without error
even when array API dispatch is enabled.
:pr:`29469` by :user:`Olivier Grisel <ogrisel>`.

Metadata Routing
----------------

Expand Down
10 changes: 10 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,16 @@ def test_check_array_array_api_has_non_finite(array_namespace):
check_array(X_inf)


@skip_if_array_api_compat_not_configured
def test_check_array_on_sparse_inputs_with_array_api_enabled():
X_sp = sp.csr_array([[0, 1, 0], [1, 0, 1]])
with config_context(array_api_dispatch=True):
assert sp.issparse(check_array(X_sp, accept_sparse=True))

with pytest.raises(TypeError):
check_array(X_sp)


@pytest.mark.parametrize(
"extension_dtype, regular_dtype",
[
Expand Down
7 changes: 5 additions & 2 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,10 @@ def check_array(
"https://numpy.org/doc/stable/reference/generated/numpy.matrix.html"
)

xp, is_array_api_compliant = get_namespace(array)
if sp.issparse(array):
xp, is_array_api_compliant = None, False
else:
xp, is_array_api_compliant = get_namespace(array)

# store reference to original array to check if copy is needed when
# function returns
Expand Down Expand Up @@ -931,7 +934,7 @@ def is_sparse(dtype):
)
)

if dtype is not None and _is_numpy_namespace(xp):
if dtype is not None and xp is not None and _is_numpy_namespace(xp):
# convert to dtype object to conform to Array API to be use `xp.isdtype` later
dtype = np.dtype(dtype)

Expand Down