Skip to content

Fix mixed dense/sparse array API namespace inspection #29466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ See :ref:`array_api` for more details.
compatible inputs when their base estimators do. :pr:`27096` by :user:`Tim
Head <betatim>` and :user:`Olivier Grisel <ogrisel>`.

Other changes impacting array API support in general:

- |Fix| :func:`validation.check_array` now accepts scipy sparse inputs without error
even when array API dispatch is enabled.
:pr:`29466` by :user:`Olivier Grisel <ogrisel>`.

Metadata Routing
----------------

Expand Down
14 changes: 14 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import wraps

import numpy
import scipy.sparse as sp
import scipy.special as special

from .._config import get_config
Expand Down Expand Up @@ -550,6 +551,19 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
# message in case it is missing.
import array_api_compat

# Special handling of scipy sparse inputs, potentially mixed with dense
# arrays.
if all(sp.issparse(a) for a in arrays):
# For all-scipy sparse inputs, it's safe to assume that the namespace
# is numpy (or its array-api-compat wrapper for versions before 2.0).
return array_api_compat.get_namespace(numpy.empty(shape=0)), True
else:
# For mixed dense/sparse array inputs, ignore the sparse arrays and
# proceed only with the dense ones. The caller code in scikit-learn
# should always be in charge of special handling scipy sparse
# operations under sp.issparse() branches when needed.
arrays = [a for a in arrays if not sp.issparse(a)]

namespace, is_array_api_compliant = array_api_compat.get_namespace(*arrays), True

# These namespaces need additional wrapping to smooth out small differences
Expand Down
34 changes: 27 additions & 7 deletions sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy
import pytest
import scipy.sparse as sp
from numpy.testing import assert_allclose

from sklearn._config import config_context
Expand Down Expand Up @@ -60,19 +61,31 @@ def test_get_namespace_ndarray_creation_device():
@skip_if_array_api_compat_not_configured
def test_get_namespace_ndarray_with_dispatch():
"""Test get_namespace on NumPy ndarrays."""
array_api_compat = pytest.importorskip("array_api_compat")
array_api_compat_numpy = pytest.importorskip("array_api_compat.numpy")

X_np = numpy.asarray([[1, 2, 3]])
if np_version >= parse_version("2.0.0"):
# NumPy 2.0+ is an array API compliant library.
expected_namespace = numpy
else:
# Older NumPy versions require the compatibility layer.
expected_namespace = array_api_compat_numpy

with config_context(array_api_dispatch=True):
xp_out, is_array_api_compliant = get_namespace(X_np)
assert is_array_api_compliant
if np_version >= parse_version("2.0.0"):
# NumPy 2.0+ is an array API compliant library.
assert xp_out is numpy
else:
# Older NumPy versions require the compatibility layer.
assert xp_out is array_api_compat.numpy
assert xp_out is expected_namespace

X_sp = sp.csr_array(X_np)
with config_context(array_api_dispatch=True):
xp_out, is_array_api_compliant = get_namespace(X_sp)
assert is_array_api_compliant
assert xp_out is expected_namespace

with config_context(array_api_dispatch=True):
xp_out, is_array_api_compliant = get_namespace(X_np, X_sp)
assert is_array_api_compliant
assert xp_out is expected_namespace


@skip_if_array_api_compat_not_configured
Expand All @@ -82,13 +95,20 @@ def test_get_namespace_array_api():

X_np = numpy.asarray([[1, 2, 3]])
X_xp = xp.asarray(X_np)
X_sp = sp.csr_array(X_np)

with config_context(array_api_dispatch=True):
xp_out, is_array_api_compliant = get_namespace(X_xp)
assert is_array_api_compliant
assert xp_out is xp

with pytest.raises(TypeError):
xp_out, is_array_api_compliant = get_namespace(X_xp, X_np)

xp_out, is_array_api_compliant = get_namespace(X_xp, X_sp)
assert is_array_api_compliant
assert xp_out is xp


class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
"""API wrapper that has an adjustable name. Used for testing."""
Expand Down
10 changes: 10 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2018,6 +2018,16 @@ def test_check_array_array_api_has_non_finite(array_namespace):
check_array(X_inf)


@skip_if_array_api_compat_not_configured
def test_check_array_on_sparse_inputs_with_array_api_enabled():
"""Check check_array on a scipy sparse input with array API dispatch enabled."""
# Small CSR sparse array used as the only input of this test.
X_sp = sp.csr_array([[0, 1, 0], [1, 0, 1]])
with config_context(array_api_dispatch=True):
# With accept_sparse=True the call must succeed and the result must
# still be a scipy sparse container.
assert sp.issparse(check_array(X_sp, accept_sparse=True))

# Called without accept_sparse=True, check_array is expected to raise
# TypeError for the sparse input.
# NOTE(review): indentation is lost in this view, so it is unclear whether
# this block runs inside the config_context above -- confirm against the
# actual file.
with pytest.raises(TypeError):
check_array(X_sp)


@pytest.mark.parametrize(
"extension_dtype, regular_dtype",
[
Expand Down
Loading