diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst
index b50815e1f7fb3..a77722100ee85 100644
--- a/doc/modules/array_api.rst
+++ b/doc/modules/array_api.rst
@@ -132,6 +132,7 @@ Metrics
 
 - :func:`sklearn.metrics.cluster.entropy`
 - :func:`sklearn.metrics.accuracy_score`
+- :func:`sklearn.metrics.confusion_matrix`
 - :func:`sklearn.metrics.d2_tweedie_score`
 - :func:`sklearn.metrics.explained_variance_score`
 - :func:`sklearn.metrics.f1_score`
diff --git a/doc/whats_new/upcoming_changes/array-api/30562.feature.rst b/doc/whats_new/upcoming_changes/array-api/30562.feature.rst
new file mode 100644
index 0000000000000..3c1a58d90bfe5
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/array-api/30562.feature.rst
@@ -0,0 +1,2 @@
+- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs.
+  By :user:`Stefanie Senger`
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index 2a08a1893766e..5a3d6d2e41781 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -29,6 +29,7 @@
 from ..utils._array_api import (
     _average,
     _bincount,
+    _convert_to_numpy,
     _count_nonzero,
     _find_matching_floating_dtype,
     _is_numpy_namespace,
@@ -275,7 +276,7 @@ def confusion_matrix(
     y_pred : array-like of shape (n_samples,)
         Estimated targets as returned by a classifier.
 
-    labels : array-like of shape (n_classes), default=None
+    labels : array-like of shape (n_classes,), default=None
         List of labels to index the matrix. This may be used to reorder
         or select a subset of labels.
         If ``None`` is given, those that appear at least once
@@ -337,6 +338,23 @@ def confusion_matrix(
     >>> (tn, fp, fn, tp)
     (0, 2, 1, 1)
     """
+    xp, _ = get_namespace(y_true, y_pred, labels, sample_weight)
+    y_true = check_array(
+        y_true,
+        dtype=None,
+        ensure_2d=False,
+        ensure_all_finite=False,
+        ensure_min_samples=0,
+    )
+    y_pred = check_array(
+        y_pred,
+        dtype=None,
+        ensure_2d=False,
+        ensure_all_finite=False,
+        ensure_min_samples=0,
+    )
+    y_true = _convert_to_numpy(y_true, xp)
+    y_pred = _convert_to_numpy(y_pred, xp)
     y_true, y_pred = attach_unique(y_true, y_pred)
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
     if y_type not in ("binary", "multiclass"):
@@ -345,10 +363,13 @@ def confusion_matrix(
     if labels is None:
         labels = unique_labels(y_true, y_pred)
     else:
-        labels = np.asarray(labels)
+        if not _is_numpy_namespace(get_namespace(labels)[0]):
+            labels = _convert_to_numpy(labels, xp)
+        else:
+            labels = np.asarray(labels)
         n_labels = labels.size
         if n_labels == 0:
-            raise ValueError("'labels' should contains at least one label.")
+            raise ValueError("'labels' should contain at least one label.")
         elif y_true.size == 0:
             return np.zeros((n_labels, n_labels), dtype=int)
         elif len(np.intersect1d(y_true, labels)) == 0:
@@ -357,7 +378,7 @@ def confusion_matrix(
     if sample_weight is None:
         sample_weight = np.ones(y_true.shape[0], dtype=np.int64)
     else:
-        sample_weight = np.asarray(sample_weight)
+        sample_weight = _convert_to_numpy(sample_weight, xp)
 
     check_consistent_length(y_true, y_pred, sample_weight)
 
@@ -371,9 +392,9 @@
         and y_pred.min() >= 0
     )
     if need_index_conversion:
-        label_to_ind = {y: x for x, y in enumerate(labels)}
-        y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
-        y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
+        label_to_ind = {label: index for index, label in enumerate(labels)}
+        y_pred = np.array([label_to_ind.get(label, n_labels + 1) for label in y_pred])
+        y_true = np.array([label_to_ind.get(label, n_labels + 1) for label in y_true])
 
     # intersect y_pred, y_true with labels, eliminate items not in labels
     ind = np.logical_and(y_pred < n_labels, y_true < n_labels)
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 21e2eed9b53cc..4fdea34144002 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -10,6 +10,7 @@
 from scipy.stats import bernoulli
 
 from sklearn import datasets, svm
+from sklearn.base import config_context
 from sklearn.datasets import make_multilabel_classification
 from sklearn.exceptions import UndefinedMetricWarning
 from sklearn.metrics import (
@@ -39,8 +40,14 @@
 from sklearn.model_selection import cross_val_score
 from sklearn.preprocessing import LabelBinarizer, label_binarize
 from sklearn.tree import DecisionTreeClassifier
+from sklearn.utils._array_api import (
+    _is_numpy_namespace,
+    get_namespace,
+    yield_namespace_device_dtype_combinations,
+)
 from sklearn.utils._mocking import MockDataFrame
 from sklearn.utils._testing import (
+    _array_api_for_tests,
     assert_allclose,
     assert_almost_equal,
     assert_array_almost_equal,
@@ -1273,7 +1280,7 @@ def test_confusion_matrix_multiclass_subset_labels():
 @pytest.mark.parametrize(
     "labels, err_msg",
     [
-        ([], "'labels' should contains at least one label."),
+        ([], "'labels' should contain at least one label."),
         ([3, 4], "At least one label specified must be in y_true"),
     ],
     ids=["empty list", "unknown labels"],
@@ -3226,3 +3233,26 @@ def test_d2_log_loss_score_raises():
     err = "The labels array needs to contain at least two"
     with pytest.raises(ValueError, match=err):
         d2_log_loss_score(y_true, y_pred, labels=labels)
+
+
+@pytest.mark.parametrize(
+    "array_namespace, device, _", yield_namespace_device_dtype_combinations()
+)
+def test_confusion_matrix_array_api(array_namespace, device, _):
+    """Test that `confusion_matrix` works for all array types when `labels` are passed
+    such that the inner boolean `need_index_conversion` evaluates to `True`."""
+    xp = _array_api_for_tests(array_namespace, device)
+
+    y_true = xp.asarray([1, 2, 3], device=device)
+    y_pred = xp.asarray([4, 5, 6], device=device)
+    labels = xp.asarray([1, 2, 3], device=device)
+
+    with config_context(array_api_dispatch=True):
+        result = confusion_matrix(y_true, y_pred, labels=labels)
+        xp_result, _ = get_namespace(result)
+        assert _is_numpy_namespace(xp_result)
+
+    # Since the computation always happens with NumPy / SciPy on the CPU, this
+    # function is expected to return an array allocated on the CPU even when it does
+    # not match the input array's device.
+    assert result.device == "cpu"
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 9e8d0ce116394..236dcdfac9533 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -2097,6 +2097,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
         check_array_api_multiclass_classification_metric,
         check_array_api_multilabel_classification_metric,
     ],
+    confusion_matrix: [
+        check_array_api_binary_classification_metric,
+        check_array_api_multiclass_classification_metric,
+    ],
     f1_score: [
         check_array_api_binary_classification_metric,
         check_array_api_multiclass_classification_metric,
diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py
index 65503a0674a70..4996e2818a623 100644
--- a/sklearn/utils/_array_api.py
+++ b/sklearn/utils/_array_api.py
@@ -861,7 +861,7 @@ def _ravel(array, xp=None):
 
 
 def _convert_to_numpy(array, xp):
-    """Convert X into a NumPy ndarray on the CPU."""
+    """Convert array into a NumPy ndarray on the CPU."""
     xp_name = xp.__name__
 
     if xp_name in {"array_api_compat.torch", "torch"}:
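A rough usage sketch of the behaviour this patch enables (not part of the diff; it assumes PyTorch and the `array-api-compat` package are installed, and the example values are invented). With array API dispatch enabled, `confusion_matrix` accepts non-NumPy inputs, and passing `labels` exercises the `need_index_conversion` branch that the new test pins down:

```python
# Illustrative only: exercises the new array API code path in confusion_matrix.
import numpy as np
import torch

from sklearn import config_context
from sklearn.metrics import confusion_matrix

# Inputs live in the torch namespace (any array API compatible library works).
y_true = torch.asarray([2, 0, 2, 2, 0, 1])
y_pred = torch.asarray([0, 0, 2, 2, 0, 2])

with config_context(array_api_dispatch=True):
    cm = confusion_matrix(y_true, y_pred, labels=torch.asarray([0, 1, 2]))

# The computation is delegated to NumPy, so the result is a plain ndarray on
# the CPU regardless of the input namespace or device.
assert isinstance(cm, np.ndarray)
print(cm)
```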