
ENH Add Array API compatibility to zero_one_loss and accuracy_score #27137


Merged: 14 commits, Sep 7, 2023
6 changes: 6 additions & 0 deletions doc/modules/array_api.rst
@@ -99,6 +99,12 @@ Estimators
- :class:`preprocessing.MaxAbsScaler`
- :class:`preprocessing.MinMaxScaler`

+Metrics
+-------
+
+- :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.zero_one_loss`
+
Tools
-----

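As a quick illustration of what this new documentation entry covers, here is a minimal usage sketch (my addition, not part of the diff; it assumes PyTorch and array-api-compat are installed, and mirrors the `config_context(array_api_dispatch=True)` pattern used by the tests further down):

```python
# Hedged sketch of the Array API support this PR adds.
# Assumes PyTorch and array-api-compat are installed.
import torch

from sklearn import config_context
from sklearn.metrics import accuracy_score, zero_one_loss

y_true = torch.asarray([0, 0, 1, 1])
y_pred = torch.asarray([0, 1, 0, 1])

with config_context(array_api_dispatch=True):
    # Both metrics now dispatch on the inputs' array namespace (torch here).
    print(accuracy_score(y_true, y_pred))  # 0.5
    print(zero_one_loss(y_true, y_pred))   # 0.5
```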
5 changes: 4 additions & 1 deletion doc/whats_new/v1.4.rst
@@ -220,7 +220,7 @@ Changelog
:pr:`26931` by `Thomas Fan`_.

- |MajorFeature| :class:`preprocessing.MinMaxScaler` and :class:`preprocessing.MaxAbsScaler` now
-  supports the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
+  support the `Array API <https://data-apis.org/array-api/latest/>`_. Array API
support is considered experimental and might evolve without being subject to
our usual rolling deprecation cycle policy. See
:ref:`array_api` for more details. :pr:`26243` by `Tim Head`_ and :pr:`27110` by :user:`Edoardo Abati <EdAbati>`.
@@ -264,6 +264,9 @@ Changelog
both axis is set to be 1 to get a square plot.
:pr:`26366` by :user:`Mojdeh Rastgoo <mrastgoo>`.

+- |Enhancement| :func:`sklearn.metrics.accuracy_score` and :func:`sklearn.metrics.zero_one_loss` now support
+  Array API compatible inputs. :pr:`27137` by :user:`Edoardo Abati <EdAbati>`.
+
:mod:`sklearn.utils`
....................

8 changes: 4 additions & 4 deletions sklearn/decomposition/tests/test_pca.py
@@ -12,12 +12,12 @@
from sklearn.decomposition import PCA
from sklearn.decomposition._pca import _assess_dimension, _infer_dimension
from sklearn.utils._array_api import (
+    _atol_for_type,
_convert_to_numpy,
yield_namespace_device_dtype_combinations,
)
-from sklearn.utils._testing import assert_allclose
+from sklearn.utils._testing import _array_api_for_tests, assert_allclose
from sklearn.utils.estimator_checks import (
-    _array_api_for_tests,
_get_check_estimator_ids,
check_array_api_input_and_values,
)
@@ -717,7 +717,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp
assert_allclose(
_convert_to_numpy(precision_xp, xp=xp),
precision_np,
-        atol=np.finfo(dtype).eps * 100,
+        atol=_atol_for_type(dtype),
)
covariance_xp = estimator_xp.get_covariance()
assert covariance_xp.shape == (4, 4)
@@ -726,7 +726,7 @@ def check_array_api_get_precision(name, estimator, array_namepsace, device, dtyp
assert_allclose(
_convert_to_numpy(covariance_xp, xp=xp),
covariance_np,
-        atol=np.finfo(dtype).eps * 100,
+        atol=_atol_for_type(dtype),
)


3 changes: 2 additions & 1 deletion sklearn/metrics/_classification.py
@@ -1046,6 +1046,7 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
>>> zero_one_loss(np.array([[0, 1], [1, 1]]), np.ones((2, 2)))
0.5
"""
+    xp, _ = get_namespace(y_true, y_pred)
score = accuracy_score(
y_true, y_pred, normalize=normalize, sample_weight=sample_weight
)
@@ -1054,7 +1055,7 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
return 1 - score
else:
if sample_weight is not None:
-            n_samples = np.sum(sample_weight)
+            n_samples = xp.sum(sample_weight)
else:
n_samples = _num_samples(y_true)
return n_samples - score
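To make the switch from `np.sum` to the namespace-aware `xp.sum` concrete, here is a small worked example of the `normalize=False` branch (plain NumPy, my numbers, not from the PR):

```python
# Sketch of zero_one_loss(normalize=False) with sample weights.
import numpy as np

from sklearn.metrics import zero_one_loss

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 0, 1])  # correct at positions 0 and 3
w = np.array([0.25, 0.25, 0.25, 0.25])

# accuracy_score(..., normalize=False) gives the weighted count of hits:
# 0.25 + 0.25 = 0.5.  n_samples = xp.sum(w) = 1.0, so the loss is
# n_samples - score = 1.0 - 0.5 = 0.5.
print(zero_one_loss(y_true, y_pred, normalize=False, sample_weight=w))  # 0.5
```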
77 changes: 77 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -6,6 +6,7 @@
import pytest
import scipy.sparse as sp

+from sklearn._config import config_context
from sklearn.datasets import make_multilabel_classification
from sklearn.metrics import (
accuracy_score,
@@ -53,7 +54,12 @@
from sklearn.metrics._base import _average_binary_score
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
+from sklearn.utils._array_api import (
+    _atol_for_type,
+    yield_namespace_device_dtype_combinations,
+)
from sklearn.utils._testing import (
+    _array_api_for_tests,
assert_allclose,
assert_almost_equal,
assert_array_equal,
@@ -1723,3 +1729,74 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
err_msg = err_msg_pos_label_1 if pos_label_default == 1 else err_msg_pos_label_None
with pytest.raises(ValueError, match=err_msg):
metric(y1, y2)


def check_array_api_metric(
metric, array_namespace, device, dtype, y_true_np, y_pred_np
):
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
y_true_xp = xp.asarray(y_true_np, device=device)
y_pred_xp = xp.asarray(y_pred_np, device=device)

metric_np = metric(y_true_np, y_pred_np)

with config_context(array_api_dispatch=True):
metric_xp = metric(y_true_xp, y_pred_xp)

assert_allclose(
metric_xp,
metric_np,
atol=_atol_for_type(dtype),
)


def check_array_api_binary_classification_metric(
metric, array_namespace, device, dtype
):
return check_array_api_metric(
metric,
array_namespace,
device,
dtype,
y_true_np=np.array([0, 0, 1, 1]),
y_pred_np=np.array([0, 1, 0, 1]),
)


def check_array_api_multiclass_classification_metric(
metric, array_namespace, device, dtype
):
return check_array_api_metric(
metric,
array_namespace,
device,
dtype,
y_true_np=np.array([0, 1, 2, 3]),
y_pred_np=np.array([0, 1, 0, 2]),
)


metric_checkers = {
accuracy_score: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
zero_one_loss: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
}


def yield_metric_checker_combinations(metric_checkers=metric_checkers):
for metric, checkers in metric_checkers.items():
for checker in checkers:
yield metric, checker


@pytest.mark.parametrize(
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
def test_array_api_compliance(metric, array_namespace, device, dtype, check_func):
check_func(metric, array_namespace, device, dtype)
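A reading aid for the parametrization above (my annotation, not part of the diff): pytest crosses every namespace/device/dtype combination with every (metric, checker) pair, and for the `metric_checkers` dict defined here the generator yields exactly four pairs:

```python
# Runs in the test module's namespace; dict insertion order is preserved.
for metric, checker in yield_metric_checker_combinations():
    print(metric.__name__, checker.__name__)
# accuracy_score check_array_api_binary_classification_metric
# accuracy_score check_array_api_multiclass_classification_metric
# zero_one_loss check_array_api_binary_classification_metric
# zero_one_loss check_array_api_multiclass_classification_metric
```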
7 changes: 6 additions & 1 deletion sklearn/utils/_array_api.py
@@ -468,7 +468,7 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)

if sample_weight is not None:
-        sample_weight = xp.asarray(sample_weight)
+        sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
if not xp.isdtype(sample_weight.dtype, "real floating"):
sample_weight = xp.astype(sample_weight, xp.float64)

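For orientation, a plain-NumPy sketch of the arithmetic `_weighted_sum` performs (my paraphrase, not the sklearn source; the real helper works on any Array API namespace):

```python
import numpy as np

def weighted_sum_sketch(sample_score, sample_weight=None, normalize=False):
    """NumPy-only mirror of _weighted_sum's arithmetic (illustrative)."""
    sample_score = np.asarray(sample_score, dtype=np.float64)
    if sample_weight is not None:
        # The change above: give sample_weight the same dtype as
        # sample_score before summing, instead of whatever asarray infers.
        sample_weight = np.asarray(sample_weight, dtype=sample_score.dtype)
    if normalize:
        return float(np.average(sample_score, weights=sample_weight))
    if sample_weight is not None:
        return float(sample_score @ sample_weight)
    return float(np.sum(sample_score))

print(weighted_sum_sketch([1, 2, 3, 4], [1, 2, 3, 4]))        # 30.0
print(weighted_sum_sketch([1, 2, 3, 4], [1, 2, 3, 4], True))  # 3.0
```

The two printed values match the `(sample_weight=[1, 2, 3, 4], normalize=False/True)` expectations in the `test_weighted_sum` case added below.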
@@ -590,3 +590,8 @@ def _estimator_with_converted_arrays(estimator, converter):
attribute = converter(attribute)
setattr(new_estimator, key, attribute)
return new_estimator


def _atol_for_type(dtype):
"""Return the absolute tolerance for a given dtype."""
return numpy.finfo(dtype).eps * 100
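Concretely, the tolerances this helper yields for the common float dtypes (exact NumPy eps values, my addition):

```python
import numpy as np

np.finfo(np.float32).eps * 100  # 1.1920929e-05
np.finfo(np.float64).eps * 100  # 2.220446049250313e-14
```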
41 changes: 41 additions & 0 deletions sklearn/utils/_testing.py
@@ -13,6 +13,7 @@
import atexit
import contextlib
import functools
+import importlib
import inspect
import os
import os.path as op
@@ -1047,3 +1048,43 @@ def transform(self, X, y=None):

def fit_transform(self, X, y=None):
return self.fit(X, y).transform(X, y)


def _array_api_for_tests(array_namespace, device, dtype):
try:
array_mod = importlib.import_module(array_namespace)
except ModuleNotFoundError:
raise SkipTest(
f"{array_namespace} is not installed: not checking array_api input"
)
try:
import array_api_compat # noqa
except ImportError:
raise SkipTest(
"array_api_compat is not installed: not checking array_api input"
)

# First create an array using the chosen array module and then get the
# corresponding (compatibility wrapped) array namespace based on it.
# This is because `cupy` is not the same as the compatibility wrapped
# namespace of a CuPy array.
xp = array_api_compat.get_namespace(array_mod.asarray(1))
if array_namespace == "torch" and device == "cuda" and not xp.has_cuda:
raise SkipTest("PyTorch test requires cuda, which is not available")
elif array_namespace == "torch" and device == "mps" and not xp.has_mps:
if not xp.backends.mps.is_built():
raise SkipTest(
"MPS is not available because the current PyTorch install was not "
"built with MPS enabled."
)
else:
raise SkipTest(
"MPS is not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine."
)
elif array_namespace in {"cupy", "cupy.array_api"}: # pragma: nocover
import cupy

if cupy.cuda.runtime.getDeviceCount() == 0:
raise SkipTest("CuPy test requires cuda, which is not available")
return xp, device, dtype
42 changes: 1 addition & 41 deletions sklearn/utils/estimator_checks.py
@@ -1,4 +1,3 @@
-import importlib
import pickle
import re
import warnings
@@ -67,6 +66,7 @@
)
from ._testing import (
SkipTest,
+    _array_api_for_tests,
_get_args,
assert_allclose,
assert_allclose_dense_sparse,
@@ -849,46 +849,6 @@ def _generate_sparse_matrix(X_csr):
yield sparse_format + "_64", X


(40 removed lines omitted here: the `_array_api_for_tests` helper deleted from this module is byte-for-byte the function added to `sklearn/utils/_testing.py` above.)

def check_array_api_input(
name,
estimator_orig,
39 changes: 38 additions & 1 deletion sklearn/utils/tests/test_array_api.py
@@ -9,15 +9,21 @@
from sklearn.utils._array_api import (
_ArrayAPIWrapper,
_asarray_with_order,
+    _atol_for_type,
_convert_to_numpy,
_estimator_with_converted_arrays,
_nanmax,
_nanmin,
_NumPyAPIWrapper,
+    _weighted_sum,
get_namespace,
supported_float_dtypes,
+    yield_namespace_device_dtype_combinations,
)
+from sklearn.utils._testing import (
+    _array_api_for_tests,
+    skip_if_array_api_compat_not_configured,
+)
-from sklearn.utils._testing import skip_if_array_api_compat_not_configured

pytestmark = pytest.mark.filterwarnings(
"ignore:The numpy.array_api submodule:UserWarning"
@@ -164,6 +170,37 @@ def test_asarray_with_order_ignored():
assert not X_new_np.flags["F_CONTIGUOUS"]


@pytest.mark.parametrize(
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize(
"sample_weight, normalize, expected",
[
(None, False, 10.0),
(None, True, 2.5),
([0.4, 0.4, 0.5, 0.7], False, 5.5),
([0.4, 0.4, 0.5, 0.7], True, 2.75),
([1, 2, 3, 4], False, 30.0),
([1, 2, 3, 4], True, 3.0),
],
)
def test_weighted_sum(
array_namespace, device, dtype, sample_weight, normalize, expected
):
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
sample_score = numpy.asarray([1, 2, 3, 4], dtype=dtype)
sample_score = xp.asarray(sample_score, device=device)
if sample_weight is not None:
sample_weight = numpy.asarray(sample_weight, dtype=dtype)
sample_weight = xp.asarray(sample_weight, device=device)

with config_context(array_api_dispatch=True):
result = _weighted_sum(sample_score, sample_weight, normalize)

assert isinstance(result, float)
assert_allclose(result, expected, atol=_atol_for_type(dtype))


@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize(
"library", ["numpy", "numpy.array_api", "cupy", "cupy.array_api", "torch"]