Merged
Changes from all commits
50 commits
78d2a65
ENH Array API for confusion_matrix
StefanieSenger Dec 7, 2024
770e638
fix dtype checking
StefanieSenger Dec 8, 2024
af440ca
prepare for PR
StefanieSenger Dec 9, 2024
b45646e
change log
StefanieSenger Dec 9, 2024
3db7054
use our _isin
StefanieSenger Dec 9, 2024
abab5ea
changes after review
StefanieSenger Dec 10, 2024
abc3981
forgot to push that before
StefanieSenger Dec 10, 2024
09cec5d
add test
StefanieSenger Dec 11, 2024
fdb25f6
fix scalar dtype
StefanieSenger Dec 12, 2024
49f75b7
fix typos
StefanieSenger Dec 12, 2024
914bb63
convert_to_numpy and coo_matrix instead of python loop
StefanieSenger Dec 18, 2024
a939c80
Merge branch 'main' into array_api_confusion_matrix
StefanieSenger Dec 23, 2024
6da1d06
experiment with convert_to_numpy
StefanieSenger Dec 23, 2024
1f23f63
np.intersect1d can stay as it is
StefanieSenger Dec 23, 2024
6a43bc3
return cm as numpy array
StefanieSenger Dec 30, 2024
2000a00
move attach unique to after conversion to numpy
StefanieSenger Dec 30, 2024
5963e0f
adjust test
StefanieSenger Dec 30, 2024
ef84e04
document return array type
StefanieSenger Dec 30, 2024
1cf525e
use get_namespace
StefanieSenger Jan 2, 2025
f50f3ea
fix issue with nullable dtypes with pandas==1.1.5
StefanieSenger Jan 2, 2025
84038e4
private function
StefanieSenger Jan 2, 2025
b47fdc7
Update sklearn/metrics/_classification.py
StefanieSenger Jan 3, 2025
1abc308
fix tests when pandas not installed
StefanieSenger Jan 3, 2025
7325cdf
better fix for environments without pandas
StefanieSenger Jan 3, 2025
ba06676
remove _nan_to_num
StefanieSenger Jan 9, 2025
3d9d986
Merge branch 'main' into array_api_confusion_matrix_numpy
StefanieSenger Jan 13, 2025
5b21ad6
remove handling for pandas < 1.2.0
StefanieSenger Jan 13, 2025
6776a11
remove handling of pandas if pandas not installed
StefanieSenger Jan 13, 2025
fa7564d
use check_array for handling pandas extension dtypes
StefanieSenger Jan 30, 2025
6ee6afc
ensure_all_finite=False
StefanieSenger Jan 30, 2025
32ea61e
add label passing to test to achieve CodeCov
StefanieSenger Jan 31, 2025
e59cd7f
fix naming
StefanieSenger Jan 31, 2025
5124f98
experiment - need to push so I can test on GPU
StefanieSenger Jan 31, 2025
5034d01
convert labels to numpy
StefanieSenger Jan 31, 2025
3cffcbf
Merge branch 'main' into array_api_confusion_matrix_numpy
StefanieSenger Jan 31, 2025
d42caa6
Update sklearn/metrics/tests/test_classification.py
StefanieSenger Feb 10, 2025
094ca6d
remove unhelpful comment
StefanieSenger Feb 10, 2025
869f568
Merge branch 'main' into array_api_confusion_matrix_numpy
StefanieSenger Feb 10, 2025
6d23419
Merge branch 'main' into array_api_confusion_matrix_numpy
StefanieSenger Jul 15, 2025
2edcf8e
convert back to original namespace and keep cpu device
StefanieSenger Jul 15, 2025
a5bf169
changes after review
StefanieSenger Jul 16, 2025
725cb24
Merge branch 'main' into array_api_confusion_matrix_numpy
StefanieSenger Jul 17, 2025
0740d3d
Update sklearn/metrics/_classification.py
StefanieSenger Aug 1, 2025
cc3fd4d
Merge branch 'main' into array_api_confusion_matrix_numpy
StefanieSenger Aug 1, 2025
cda414c
fix linting
StefanieSenger Aug 1, 2025
ea480bd
Adapt after _check_targets change [azure parallel]
lesteve Aug 1, 2025
a84671c
empty commit to re-trigger CI
StefanieSenger Aug 1, 2025
82bf221
empty commit to re-trigger CI [azure parallel]
StefanieSenger Aug 1, 2025
d1b3439
Tackle special case of empty inputs with _check_targets recent change…
lesteve Aug 1, 2025
5b103e7
Merge branch 'main' into array_api_confusion_matrix_numpy
OmarManzoor Aug 4, 2025
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
@@ -141,6 +141,7 @@ Metrics
-------

- :func:`sklearn.metrics.accuracy_score`
- :func:`sklearn.metrics.confusion_matrix`
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30562.feature.rst
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs.
By :user:`Stefanie Senger <StefanieSenger>`
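
For context, a minimal usage sketch of the new behavior (a hedged example, assuming PyTorch and array-api-compat are installed; with array API dispatch enabled, the result comes back in the input namespace and on the input device):

import torch
from sklearn import config_context
from sklearn.metrics import confusion_matrix

y_true = torch.as_tensor([0, 1, 2, 2])
y_pred = torch.as_tensor([0, 2, 2, 2])

# With dispatch enabled, the confusion matrix is computed with NumPy on CPU
# internally but returned as a torch.Tensor on the inputs' device.
with config_context(array_api_dispatch=True):
    cm = confusion_matrix(y_true, y_pred)

print(type(cm), cm.shape)  # <class 'torch.Tensor'> torch.Size([3, 3])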
60 changes: 47 additions & 13 deletions sklearn/metrics/_classification.py
@@ -29,6 +29,7 @@
from sklearn.utils._array_api import (
_average,
_bincount,
_convert_to_numpy,
_count_nonzero,
_find_matching_floating_dtype,
_is_numpy_namespace,
@@ -413,7 +414,7 @@ def confusion_matrix(
y_pred : array-like of shape (n_samples,)
Estimated targets as returned by a classifier.

labels : array-like of shape (n_classes), default=None
labels : array-like of shape (n_classes,), default=None
List of labels to index the matrix. This may be used to reorder
or select a subset of labels.
If ``None`` is given, those that appear at least once
@@ -475,28 +476,61 @@
>>> (tn, fp, fn, tp)
(0, 2, 1, 1)
"""
y_true, y_pred = attach_unique(y_true, y_pred)
y_type, y_true, y_pred, sample_weight = _check_targets(
y_true, y_pred, sample_weight
xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight)
y_true = check_array(
y_true,
dtype=None,
ensure_2d=False,
ensure_all_finite=False,
ensure_min_samples=0,
)
y_pred = check_array(
y_pred,
dtype=None,
ensure_2d=False,
ensure_all_finite=False,
ensure_min_samples=0,
)
# Convert the input arrays to NumPy (on CPU) irrespective of the original
# namespace and device so as to be able to leverage the efficient
# counting operations implemented by SciPy in the coo_matrix constructor.
# The final results will be converted back to the input namespace and device
# for the sake of consistency with other metric functions with array API support.
y_true = _convert_to_numpy(y_true, xp)
y_pred = _convert_to_numpy(y_pred, xp)
if sample_weight is None:
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
else:
sample_weight = _convert_to_numpy(sample_weight, xp)

if len(sample_weight) > 0:
y_type, y_true, y_pred, sample_weight = _check_targets(
y_true, y_pred, sample_weight
)
else:
# This is needed to handle the special case where y_true, y_pred and
# sample_weight are all empty.
# In this case we don't pass sample_weight to _check_targets that would
# check that sample_weight is not empty and we don't reuse the returned
# sample_weight
y_type, y_true, y_pred, _ = _check_targets(y_true, y_pred)

y_true, y_pred = attach_unique(y_true, y_pred)
if y_type not in ("binary", "multiclass"):
raise ValueError("%s is not supported" % y_type)

if labels is None:
labels = unique_labels(y_true, y_pred)
else:
labels = np.asarray(labels)
labels = _convert_to_numpy(labels, xp)
n_labels = labels.size
if n_labels == 0:
raise ValueError("'labels' should contains at least one label.")
raise ValueError("'labels' should contain at least one label.")
elif y_true.size == 0:
return np.zeros((n_labels, n_labels), dtype=int)
elif len(np.intersect1d(y_true, labels)) == 0:
raise ValueError("At least one label specified must be in y_true")

if sample_weight is None:
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)

n_labels = labels.size
# If labels are not consecutive integers starting from zero, then
# y_true and y_pred must be converted into index form
@@ -507,9 +541,9 @@ def confusion_matrix(
and y_pred.min() >= 0
)
if need_index_conversion:
label_to_ind = {y: x for x, y in enumerate(labels)}
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
label_to_ind = {label: index for index, label in enumerate(labels)}
y_pred = np.array([label_to_ind.get(label, n_labels + 1) for label in y_pred])
y_true = np.array([label_to_ind.get(label, n_labels + 1) for label in y_true])

# intersect y_pred, y_true with labels, eliminate items not in labels
ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
@@ -550,7 +584,7 @@
UserWarning,
)

return cm
return xp.asarray(cm, device=device_)


@validate_params(
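
For reference, a minimal standalone sketch of the counting trick mentioned in the comment above: scipy.sparse.coo_matrix sums the weights of duplicate (row, column) coordinate pairs at construction time, so building it from (y_true, y_pred) index pairs yields the weighted per-cell counts in one vectorized step.

import numpy as np
from scipy.sparse import coo_matrix

y_true = np.array([0, 1, 2, 2])
y_pred = np.array([0, 2, 2, 2])
sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
n_labels = 3

# Duplicate (true, pred) pairs have their weights summed, so each cell of
# the dense result is the (weighted) count for that pair of labels.
cm = coo_matrix(
    (sample_weight, (y_true, y_pred)), shape=(n_labels, n_labels)
).toarray()
print(cm)
# [[1 0 0]
#  [0 0 1]
#  [0 0 2]]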
37 changes: 34 additions & 3 deletions sklearn/metrics/tests/test_classification.py
@@ -10,6 +10,7 @@
from scipy.stats import bernoulli

from sklearn import datasets, svm
from sklearn.base import config_context
from sklearn.datasets import make_multilabel_classification
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import (
@@ -43,8 +44,16 @@
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelBinarizer, label_binarize
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._array_api import (
device as array_api_device,
)
from sklearn.utils._array_api import (
get_namespace,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._testing import (
_array_api_for_tests,
assert_allclose,
assert_almost_equal,
assert_array_almost_equal,
@@ -1269,7 +1278,7 @@ def test_confusion_matrix_multiclass_subset_labels():
@pytest.mark.parametrize(
"labels, err_msg",
[
([], "'labels' should contains at least one label."),
([], "'labels' should contain at least one label."),
([3, 4], "At least one label specified must be in y_true"),
],
ids=["empty list", "unknown labels"],
@@ -1283,10 +1292,14 @@ def test_confusion_matrix_error(labels, err_msg):
@pytest.mark.parametrize(
"labels", (None, [0, 1], [0, 1, 2]), ids=["None", "binary", "multiclass"]
)
def test_confusion_matrix_on_zero_length_input(labels):
@pytest.mark.parametrize(
"sample_weight",
(None, []),
)
def test_confusion_matrix_on_zero_length_input(labels, sample_weight):
expected_n_classes = len(labels) if labels else 0
expected = np.zeros((expected_n_classes, expected_n_classes), dtype=int)
cm = confusion_matrix([], [], labels=labels)
cm = confusion_matrix([], [], sample_weight=sample_weight, labels=labels)
assert_array_equal(cm, expected)


@@ -3608,3 +3621,21 @@ def test_d2_brier_score_warning_on_less_than_two_samples():
warning_message = "not well-defined with less than two samples"
with pytest.warns(UndefinedMetricWarning, match=warning_message):
d2_brier_score(y_true, y_pred)


@pytest.mark.parametrize(
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
)
def test_confusion_matrix_array_api(array_namespace, device, _):
"""Test that `confusion_matrix` works for all array types when `labels` are passed
such that the inner boolean `need_index_conversion` evaluates to `True`."""
xp = _array_api_for_tests(array_namespace, device)

y_true = xp.asarray([1, 2, 3], device=device)
y_pred = xp.asarray([4, 5, 6], device=device)
labels = xp.asarray([1, 2, 3], device=device)

with config_context(array_api_dispatch=True):
result = confusion_matrix(y_true, y_pred, labels=labels)
assert get_namespace(result)[0] == get_namespace(y_pred)[0]
assert array_api_device(result) == array_api_device(y_pred)
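
The round trip asserted above can also be checked by hand; a hedged sketch mirroring the test, assuming array-api-strict and array-api-compat are installed:

import array_api_strict as xp
from sklearn import config_context
from sklearn.metrics import confusion_matrix

y_true = xp.asarray([1, 2, 3])
y_pred = xp.asarray([4, 5, 6])

with config_context(array_api_dispatch=True):
    cm = confusion_matrix(y_true, y_pred, labels=xp.asarray([1, 2, 3]))

# No prediction matches any of the listed labels, so the matrix is all
# zeros, but it is returned as an array_api_strict array, not NumPy.
print(type(cm), cm.shape)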
4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -2225,6 +2225,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_multiclass_classification_metric,
check_array_api_multilabel_classification_metric,
],
confusion_matrix: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
f1_score: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,