ENH Array API support for confusion_matrix converting to numpy array #30562

Open. Wants to merge 38 commits into base: main.

Commits (38), all by StefanieSenger:

- 78d2a65 ENH Array API for confusion_matrix (Dec 7, 2024)
- 770e638 fix dtype checking (Dec 8, 2024)
- af440ca prepare for PR (Dec 9, 2024)
- b45646e change log (Dec 9, 2024)
- 3db7054 use our _isin (Dec 9, 2024)
- abab5ea changes after review (Dec 10, 2024)
- abc3981 forgot to push that before (Dec 10, 2024)
- 09cec5d add test (Dec 11, 2024)
- fdb25f6 fix sclar dtype (Dec 12, 2024)
- 49f75b7 fix typos (Dec 12, 2024)
- 914bb63 convert_to_numpy and coo_matrix instead of python loop (Dec 18, 2024)
- a939c80 Merge branch 'main' into array_api_confusion_matrix (Dec 23, 2024)
- 6da1d06 experiment with convert_to_numpy (Dec 23, 2024)
- 1f23f63 np.intersect1d can stay as it is (Dec 23, 2024)
- 6a43bc3 return cm as numpy array (Dec 30, 2024)
- 2000a00 move attach unique to after conversion to numpy (Dec 30, 2024)
- 5963e0f adjust test (Dec 30, 2024)
- ef84e04 document return array type (Dec 30, 2024)
- 1cf525e use get_namespace (Jan 2, 2025)
- f50f3ea fix issue with nullable dtypes with pandas==1.1.5 (Jan 2, 2025)
- 84038e4 private function (Jan 2, 2025)
- b47fdc7 Update sklearn/metrics/_classification.py (Jan 3, 2025)
- 1abc308 fix tests when pandas not installed (Jan 3, 2025)
- 7325cdf better fix for environments without pandas (Jan 3, 2025)
- ba06676 remove _nan_to_num (Jan 9, 2025)
- 3d9d986 Merge branch 'main' into array_api_confusion_matrix_numpy (Jan 13, 2025)
- 5b21ad6 remove handling for pandas < 1.2.0 (Jan 13, 2025)
- 6776a11 remove handling of pandas if pandas not installed (Jan 13, 2025)
- fa7564d use check_array for handling pandas extension dtypes (Jan 30, 2025)
- 6ee6afc ensure_all_finite=False (Jan 30, 2025)
- 32ea61e add label passing to test to archive CodeCov (Jan 31, 2025)
- e59cd7f fix naming (Jan 31, 2025)
- 5124f98 experiment - need to push so I can test on GPU (Jan 31, 2025)
- 5034d01 convert labels to numpy (Jan 31, 2025)
- 3cffcbf Merge branch 'main' into array_api_confusion_matrix_numpy (Jan 31, 2025)
- d42caa6 Update sklearn/metrics/tests/test_classification.py (Feb 10, 2025)
- 094ca6d remove unhelpful comment (Feb 10, 2025)
- 869f568 Merge branch 'main' into array_api_confusion_matrix_numpy (Feb 10, 2025)

1 change: 1 addition & 0 deletions doc/modules/array_api.rst
@@ -132,6 +132,7 @@ Metrics

 - :func:`sklearn.metrics.cluster.entropy`
 - :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.confusion_matrix`
 - :func:`sklearn.metrics.d2_tweedie_score`
 - :func:`sklearn.metrics.explained_variance_score`
 - :func:`sklearn.metrics.f1_score`

2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30562.feature.rst
@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs.
+  By :user:`Stefanie Senger <StefanieSenger>`
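
A minimal usage sketch of the new behavior, assuming PyTorch is installed and
array API dispatch is enabled: `confusion_matrix` accepts array API compatible
inputs but always returns a NumPy array allocated on the CPU.

    import numpy as np
    import torch

    from sklearn import config_context
    from sklearn.metrics import confusion_matrix

    # Hypothetical toy data; any array API compatible container should work.
    y_true = torch.asarray([1, 2, 3])
    y_pred = torch.asarray([1, 2, 2])

    with config_context(array_api_dispatch=True):
        cm = confusion_matrix(y_true, y_pred)

    # The result is converted back to NumPy regardless of the input namespace.
    assert isinstance(cm, np.ndarray)
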
35 changes: 28 additions & 7 deletions sklearn/metrics/_classification.py
@@ -29,6 +29,7 @@
 from ..utils._array_api import (
     _average,
     _bincount,
+    _convert_to_numpy,
     _count_nonzero,
     _find_matching_floating_dtype,
     _is_numpy_namespace,
@@ -275,7 +276,7 @@ def confusion_matrix(
     y_pred : array-like of shape (n_samples,)
         Estimated targets as returned by a classifier.

-    labels : array-like of shape (n_classes), default=None
+    labels : array-like of shape (n_classes,), default=None
         List of labels to index the matrix. This may be used to reorder
         or select a subset of labels.
         If ``None`` is given, those that appear at least once
@@ -337,6 +338,23 @@
     >>> (tn, fp, fn, tp)
     (0, 2, 1, 1)
     """
+    xp, _ = get_namespace(y_true, y_pred, labels, sample_weight)
+    y_true = check_array(
+        y_true,
+        dtype=None,
+        ensure_2d=False,
+        ensure_all_finite=False,
+        ensure_min_samples=0,
+    )
+    y_pred = check_array(
+        y_pred,
+        dtype=None,
+        ensure_2d=False,
+        ensure_all_finite=False,
+        ensure_min_samples=0,
+    )
+    y_true = _convert_to_numpy(y_true, xp)
+    y_pred = _convert_to_numpy(y_pred, xp)
     y_true, y_pred = attach_unique(y_true, y_pred)
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
     if y_type not in ("binary", "multiclass"):
@@ -345,10 +363,13 @@
     if labels is None:
         labels = unique_labels(y_true, y_pred)
     else:
-        labels = np.asarray(labels)
+        if not _is_numpy_namespace(get_namespace(labels)[0]):
+            labels = _convert_to_numpy(labels, xp)
+        else:
+            labels = np.asarray(labels)
Review comment (Member) on lines +366 to +369:

I think we can replace these with a single line:

    labels = _convert_to_numpy(labels, xp)
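
A note on why this should suffice: assuming `_convert_to_numpy` falls back to
`numpy.asarray` for arrays already in the NumPy namespace (see the hedged
sketch of the helper at the end of this section), the explicit
`_is_numpy_namespace` branch would be redundant.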

     n_labels = labels.size
     if n_labels == 0:
-        raise ValueError("'labels' should contains at least one label.")
+        raise ValueError("'labels' should contain at least one label.")
     elif y_true.size == 0:
         return np.zeros((n_labels, n_labels), dtype=int)
     elif len(np.intersect1d(y_true, labels)) == 0:
@@ -357,7 +378,7 @@
     if sample_weight is None:
         sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
     else:
-        sample_weight = np.asarray(sample_weight)
+        sample_weight = _convert_to_numpy(sample_weight, xp)

     check_consistent_length(y_true, y_pred, sample_weight)

@@ -371,9 +392,9 @@
         and y_pred.min() >= 0
     )
     if need_index_conversion:
-        label_to_ind = {y: x for x, y in enumerate(labels)}
-        y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
-        y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
+        label_to_ind = {label: index for index, label in enumerate(labels)}
+        y_pred = np.array([label_to_ind.get(label, n_labels + 1) for label in y_pred])
+        y_true = np.array([label_to_ind.get(label, n_labels + 1) for label in y_true])

     # intersect y_pred, y_true with labels, eliminate items not in labels
     ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
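
For illustration, a standalone sketch of the index-conversion step shown above,
with hypothetical toy data: labels absent from `labels` are mapped to the
out-of-range index `n_labels + 1` and then dropped by the bounds check.

    import numpy as np

    labels = np.array(["a", "b", "c"])
    y_true = np.array(["a", "b", "x"])
    y_pred = np.array(["b", "b", "c"])

    n_labels = labels.size
    label_to_ind = {label: index for index, label in enumerate(labels)}
    y_true_idx = np.array([label_to_ind.get(label, n_labels + 1) for label in y_true])
    y_pred_idx = np.array([label_to_ind.get(label, n_labels + 1) for label in y_pred])

    # Keep only samples whose true and predicted labels both appear in `labels`;
    # the third sample (true label "x") is eliminated here.
    ind = np.logical_and(y_pred_idx < n_labels, y_true_idx < n_labels)
    print(y_true_idx[ind], y_pred_idx[ind])  # [0 1] [1 1]
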
32 changes: 31 additions & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -10,6 +10,7 @@
 from scipy.stats import bernoulli

 from sklearn import datasets, svm
+from sklearn.base import config_context
 from sklearn.datasets import make_multilabel_classification
 from sklearn.exceptions import UndefinedMetricWarning
 from sklearn.metrics import (
@@ -39,8 +40,14 @@
 from sklearn.model_selection import cross_val_score
 from sklearn.preprocessing import LabelBinarizer, label_binarize
 from sklearn.tree import DecisionTreeClassifier
+from sklearn.utils._array_api import (
+    _is_numpy_namespace,
+    get_namespace,
+    yield_namespace_device_dtype_combinations,
+)
 from sklearn.utils._mocking import MockDataFrame
 from sklearn.utils._testing import (
+    _array_api_for_tests,
     assert_allclose,
     assert_almost_equal,
     assert_array_almost_equal,
@@ -1273,7 +1280,7 @@ def test_confusion_matrix_multiclass_subset_labels():
 @pytest.mark.parametrize(
     "labels, err_msg",
     [
-        ([], "'labels' should contains at least one label."),
+        ([], "'labels' should contain at least one label."),
         ([3, 4], "At least one label specified must be in y_true"),
     ],
     ids=["empty list", "unknown labels"],
@@ -3226,3 +3233,26 @@ def test_d2_log_loss_score_raises():
err = "The labels array needs to contain at least two"
with pytest.raises(ValueError, match=err):
d2_log_loss_score(y_true, y_pred, labels=labels)


@pytest.mark.parametrize(
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
)
def test_confusion_matrix_array_api(array_namespace, device, _):
"""Test that `confusion_matrix` works for all array types when `labels` are passed
such that the inner boolean `need_index_conversion` evaluates to `True`."""
xp = _array_api_for_tests(array_namespace, device)

y_true = xp.asarray([1, 2, 3], device=device)
y_pred = xp.asarray([4, 5, 6], device=device)
labels = xp.asarray([1, 2, 3], device=device)

with config_context(array_api_dispatch=True):
result = confusion_matrix(y_true, y_pred, labels=labels)
xp_result, _ = get_namespace(result)
assert _is_numpy_namespace(xp_result)

# Since the computation always happens with NumPy / SciPy on the CPU, this
# function is expected to return an array allocated on the CPU even when it does
# not match the input array's device.
assert result.device == "cpu"
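
As a usage note, this test can be run in isolation with pytest's `-k` selector,
e.g. `pytest sklearn/metrics/tests/test_classification.py -k test_confusion_matrix_array_api`.
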
4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -2097,6 +2097,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_multiclass_classification_metric,
         check_array_api_multilabel_classification_metric,
     ],
+    confusion_matrix: [
+        check_array_api_binary_classification_metric,
+        check_array_api_multiclass_classification_metric,
+    ],
     f1_score: [
         check_array_api_binary_classification_metric,
         check_array_api_multiclass_classification_metric,

2 changes: 1 addition & 1 deletion sklearn/utils/_array_api.py
@@ -861,7 +861,7 @@ def _ravel(array, xp=None):


 def _convert_to_numpy(array, xp):
-    """Convert X into a NumPy ndarray on the CPU."""
+    """Convert array into a NumPy ndarray on the CPU."""
     xp_name = xp.__name__

     if xp_name in {"array_api_compat.torch", "torch"}:
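
For context, a hedged sketch of what the full helper roughly looks like (the
diff above truncates it): it special-cases namespaces whose arrays may live on
an accelerator device and falls back to `numpy.asarray` otherwise. The function
name and branches below are a reconstruction from the visible lines, not the
verbatim source.

    import numpy

    def _convert_to_numpy_sketch(array, xp):
        """Convert array into a NumPy ndarray on the CPU (sketch)."""
        xp_name = xp.__name__

        if xp_name in {"array_api_compat.torch", "torch"}:
            # Torch tensors must be moved off the accelerator before conversion.
            return array.cpu().numpy()
        elif xp_name in {"array_api_compat.cupy", "cupy"}:
            # CuPy arrays expose .get() to copy device memory to the host.
            return array.get()

        # NumPy and NumPy-like inputs pass straight through.
        return numpy.asarray(array)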