Commit 1ff785e

StefanieSenger, virchan, ogrisel, lesteve, and OmarManzoor authored
ENH Array API support for confusion_matrix (#30562)
Co-authored-by: Virgil Chan <virchan.math@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent e890e6b commit 1ff785e
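
For context, a minimal sketch of what this change enables, assuming PyTorch and the array-api-compat package are installed (any array API compatible namespace and device works the same way):

```python
import torch

from sklearn import config_context
from sklearn.metrics import confusion_matrix

y_true = torch.tensor([0, 1, 2, 2])
y_pred = torch.tensor([0, 2, 2, 1])

# With array API dispatch enabled, the confusion matrix is computed on
# NumPy internally but returned in the input namespace and on the input
# device, as described in the diff below.
with config_context(array_api_dispatch=True):
    cm = confusion_matrix(y_true, y_pred)

print(type(cm))  # <class 'torch.Tensor'>
```

This round trip back to the input namespace and device mirrors the other metric functions with array API support.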

File tree

5 files changed: +88 −16 lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -141,6 +141,7 @@ Metrics
 -------
 
 - :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.confusion_matrix`
 - :func:`sklearn.metrics.d2_tweedie_score`
 - :func:`sklearn.metrics.explained_variance_score`
 - :func:`sklearn.metrics.f1_score`
```
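
As a reminder, array API dispatch is opt-in; a minimal sketch of enabling it globally, as an alternative to the `config_context` form used in the tests below:

```python
import sklearn

# Enable array API dispatch globally; confusion_matrix then accepts and
# returns arrays from any supported array API namespace.
sklearn.set_config(array_api_dispatch=True)
```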
Changelog fragment (new file; path not shown in this view)

Lines changed: 2 additions & 0 deletions

```diff
@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs.
+  By :user:`Stefanie Senger <StefanieSenger>`
```

sklearn/metrics/_classification.py

Lines changed: 47 additions & 13 deletions

```diff
@@ -29,6 +29,7 @@
 from sklearn.utils._array_api import (
     _average,
     _bincount,
+    _convert_to_numpy,
     _count_nonzero,
     _find_matching_floating_dtype,
     _is_numpy_namespace,
@@ -413,7 +414,7 @@ def confusion_matrix(
     y_pred : array-like of shape (n_samples,)
         Estimated targets as returned by a classifier.
 
-    labels : array-like of shape (n_classes), default=None
+    labels : array-like of shape (n_classes,), default=None
         List of labels to index the matrix. This may be used to reorder
         or select a subset of labels.
         If ``None`` is given, those that appear at least once
@@ -475,28 +476,61 @@ def confusion_matrix(
     >>> (tn, fp, fn, tp)
     (0, 2, 1, 1)
     """
-    y_true, y_pred = attach_unique(y_true, y_pred)
-    y_type, y_true, y_pred, sample_weight = _check_targets(
-        y_true, y_pred, sample_weight
+    xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight)
+    y_true = check_array(
+        y_true,
+        dtype=None,
+        ensure_2d=False,
+        ensure_all_finite=False,
+        ensure_min_samples=0,
     )
+    y_pred = check_array(
+        y_pred,
+        dtype=None,
+        ensure_2d=False,
+        ensure_all_finite=False,
+        ensure_min_samples=0,
+    )
+    # Convert the input arrays to NumPy (on CPU) irrespective of the original
+    # namespace and device so as to be able to leverage the efficient
+    # counting operations implemented by SciPy in the coo_matrix constructor.
+    # The final results will be converted back to the input namespace and device
+    # for the sake of consistency with other metric functions with array API support.
+    y_true = _convert_to_numpy(y_true, xp)
+    y_pred = _convert_to_numpy(y_pred, xp)
+    if sample_weight is None:
+        sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
+    else:
+        sample_weight = _convert_to_numpy(sample_weight, xp)
+
+    if len(sample_weight) > 0:
+        y_type, y_true, y_pred, sample_weight = _check_targets(
+            y_true, y_pred, sample_weight
+        )
+    else:
+        # This is needed to handle the special case where y_true, y_pred and
+        # sample_weight are all empty.
+        # In this case we don't pass sample_weight to _check_targets that would
+        # check that sample_weight is not empty and we don't reuse the returned
+        # sample_weight
+        y_type, y_true, y_pred, _ = _check_targets(y_true, y_pred)
+
+    y_true, y_pred = attach_unique(y_true, y_pred)
     if y_type not in ("binary", "multiclass"):
         raise ValueError("%s is not supported" % y_type)
 
     if labels is None:
         labels = unique_labels(y_true, y_pred)
     else:
-        labels = np.asarray(labels)
+        labels = _convert_to_numpy(labels, xp)
         n_labels = labels.size
         if n_labels == 0:
-            raise ValueError("'labels' should contains at least one label.")
+            raise ValueError("'labels' should contain at least one label.")
         elif y_true.size == 0:
             return np.zeros((n_labels, n_labels), dtype=int)
         elif len(np.intersect1d(y_true, labels)) == 0:
             raise ValueError("At least one label specified must be in y_true")
 
-    if sample_weight is None:
-        sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
-
     n_labels = labels.size
     # If labels are not consecutive integers starting from zero, then
     # y_true and y_pred must be converted into index form
@@ -507,9 +541,9 @@ def confusion_matrix(
         and y_pred.min() >= 0
     )
     if need_index_conversion:
-        label_to_ind = {y: x for x, y in enumerate(labels)}
-        y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
-        y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
+        label_to_ind = {label: index for index, label in enumerate(labels)}
+        y_pred = np.array([label_to_ind.get(label, n_labels + 1) for label in y_pred])
+        y_true = np.array([label_to_ind.get(label, n_labels + 1) for label in y_true])
 
     # intersect y_pred, y_true with labels, eliminate items not in labels
     ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
@@ -550,7 +584,7 @@ def confusion_matrix(
             UserWarning,
         )
 
-    return cm
+    return xp.asarray(cm, device=device_)
 
 
 @validate_params(
```
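
The in-diff comment above refers to the counting operations implemented by SciPy in the coo_matrix constructor; a minimal illustrative sketch of that idea (example values, not the library code verbatim):

```python
import numpy as np
from scipy.sparse import coo_matrix

# Labels already converted to indices 0..n_labels-1.
y_true = np.array([0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 1, 2, 2, 2, 1])
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
n_labels = 3

# coo_matrix sums duplicate (row, col) entries, so feeding it the
# (true, predicted) index pairs with sample weights as data counts each
# pair in one vectorized pass; rows are true labels, columns predictions.
cm = coo_matrix(
    (sample_weight, (y_true, y_pred)), shape=(n_labels, n_labels)
).toarray()
print(cm)
# [[1 0 0]
#  [0 1 1]
#  [0 1 2]]
```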

sklearn/metrics/tests/test_classification.py

Lines changed: 34 additions & 3 deletions

```diff
@@ -10,6 +10,7 @@
 from scipy.stats import bernoulli
 
 from sklearn import datasets, svm
+from sklearn.base import config_context
 from sklearn.datasets import make_multilabel_classification
 from sklearn.exceptions import UndefinedMetricWarning
 from sklearn.metrics import (
@@ -43,8 +44,16 @@
 from sklearn.model_selection import cross_val_score
 from sklearn.preprocessing import LabelBinarizer, label_binarize
 from sklearn.tree import DecisionTreeClassifier
+from sklearn.utils._array_api import (
+    device as array_api_device,
+)
+from sklearn.utils._array_api import (
+    get_namespace,
+    yield_namespace_device_dtype_combinations,
+)
 from sklearn.utils._mocking import MockDataFrame
 from sklearn.utils._testing import (
+    _array_api_for_tests,
     assert_allclose,
     assert_almost_equal,
     assert_array_almost_equal,
@@ -1269,7 +1278,7 @@ def test_confusion_matrix_multiclass_subset_labels():
 @pytest.mark.parametrize(
     "labels, err_msg",
     [
-        ([], "'labels' should contains at least one label."),
+        ([], "'labels' should contain at least one label."),
         ([3, 4], "At least one label specified must be in y_true"),
     ],
     ids=["empty list", "unknown labels"],
@@ -1283,10 +1292,14 @@ def test_confusion_matrix_error(labels, err_msg):
 @pytest.mark.parametrize(
     "labels", (None, [0, 1], [0, 1, 2]), ids=["None", "binary", "multiclass"]
 )
-def test_confusion_matrix_on_zero_length_input(labels):
+@pytest.mark.parametrize(
+    "sample_weight",
+    (None, []),
+)
+def test_confusion_matrix_on_zero_length_input(labels, sample_weight):
     expected_n_classes = len(labels) if labels else 0
     expected = np.zeros((expected_n_classes, expected_n_classes), dtype=int)
-    cm = confusion_matrix([], [], labels=labels)
+    cm = confusion_matrix([], [], sample_weight=sample_weight, labels=labels)
     assert_array_equal(cm, expected)
@@ -3608,3 +3621,21 @@ def test_d2_brier_score_warning_on_less_than_two_samples():
     warning_message = "not well-defined with less than two samples"
     with pytest.warns(UndefinedMetricWarning, match=warning_message):
         d2_brier_score(y_true, y_pred)
+
+
+@pytest.mark.parametrize(
+    "array_namespace, device, _", yield_namespace_device_dtype_combinations()
+)
+def test_confusion_matrix_array_api(array_namespace, device, _):
+    """Test that `confusion_matrix` works for all array types when `labels` are passed
+    such that the inner boolean `need_index_conversion` evaluates to `True`."""
+    xp = _array_api_for_tests(array_namespace, device)
+
+    y_true = xp.asarray([1, 2, 3], device=device)
+    y_pred = xp.asarray([4, 5, 6], device=device)
+    labels = xp.asarray([1, 2, 3], device=device)
+
+    with config_context(array_api_dispatch=True):
+        result = confusion_matrix(y_true, y_pred, labels=labels)
+        assert get_namespace(result)[0] == get_namespace(y_pred)[0]
+        assert array_api_device(result) == array_api_device(y_pred)
```
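
The new sample_weight parametrization above pins down the empty-input path; a small illustrative sketch of the behavior it asserts (hypothetical example values):

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# With empty inputs and explicit labels, the result is an all-zero
# (n_labels, n_labels) matrix; an empty sample_weight is now accepted
# as well, since it bypasses the non-empty check in _check_targets.
cm = confusion_matrix([], [], sample_weight=[], labels=[0, 1])
print(cm)
# [[0 0]
#  [0 0]]
```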

sklearn/metrics/tests/test_common.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -2225,6 +2225,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_multiclass_classification_metric,
         check_array_api_multilabel_classification_metric,
     ],
+    confusion_matrix: [
+        check_array_api_binary_classification_metric,
+        check_array_api_multiclass_classification_metric,
+    ],
     f1_score: [
         check_array_api_binary_classification_metric,
         check_array_api_multiclass_classification_metric,
```