diff --git a/doc/modules/array_api.rst b/doc/modules/array_api.rst index 2f6e16a89a9ea..78eef9b392356 100644 --- a/doc/modules/array_api.rst +++ b/doc/modules/array_api.rst @@ -141,6 +141,7 @@ Metrics ------- - :func:`sklearn.metrics.accuracy_score` +- :func:`sklearn.metrics.confusion_matrix` - :func:`sklearn.metrics.d2_tweedie_score` - :func:`sklearn.metrics.explained_variance_score` - :func:`sklearn.metrics.f1_score` diff --git a/doc/whats_new/upcoming_changes/array-api/30440.feature.rst b/doc/whats_new/upcoming_changes/array-api/30440.feature.rst new file mode 100644 index 0000000000000..d1f1374f28577 --- /dev/null +++ b/doc/whats_new/upcoming_changes/array-api/30440.feature.rst @@ -0,0 +1,2 @@ +- :func:`sklearn.metrics.confusion_matrix` now supports Array API compatible inputs. + By :user:`Stefanie Senger <StefanieSenger>` diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 06503046790be..e3cc6eee24af5 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -29,15 +29,19 @@ from ..utils._array_api import ( _average, _bincount, + _convert_to_numpy, _count_nonzero, _find_matching_floating_dtype, _is_numpy_namespace, + _isin, _max_precision_float_dtype, + _nan_to_num, _searchsorted, _tolist, _union1d, get_namespace, get_namespace_and_device, + size, xpx, ) from ..utils._param_validation import ( @@ -401,7 +405,7 @@ def confusion_matrix( y_pred : array-like of shape (n_samples,) Estimated targets as returned by a classifier. - labels : array-like of shape (n_classes), default=None + labels : array-like of shape (n_classes,), default=None List of labels to index the matrix. This may be used to reorder or select a subset of labels. 
If ``None`` is given, those that appear at least once @@ -419,7 +423,7 @@ def confusion_matrix( Returns ------- - C : ndarray of shape (n_classes, n_classes) + C : array of shape (n_classes, n_classes) Confusion matrix whose i-th row and j-th column entry indicates the number of samples with true label being i-th class @@ -464,6 +468,7 @@ def confusion_matrix( (0, 2, 1, 1) """ y_true, y_pred = attach_unique(y_true, y_pred) + xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight) y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in ("binary", "multiclass"): raise ValueError("%s is not supported" % y_type) @@ -471,55 +476,69 @@ def confusion_matrix( if labels is None: labels = unique_labels(y_true, y_pred) else: - labels = np.asarray(labels) + labels = xp.asarray(labels) n_labels = labels.size if n_labels == 0: - raise ValueError("'labels' should contains at least one label.") + raise ValueError("'labels' should contain at least one label.") elif y_true.size == 0: - return np.zeros((n_labels, n_labels), dtype=int) - elif len(np.intersect1d(y_true, labels)) == 0: + return xp.zeros((n_labels, n_labels), dtype=xp.int64, device=device_) + elif not _isin(labels, y_true, xp=xp).any(): raise ValueError("At least one label specified must be in y_true") if sample_weight is None: - sample_weight = np.ones(y_true.shape[0], dtype=np.int64) + sample_weight = xp.ones(y_true.shape[0], dtype=xp.int64, device=device_) else: - sample_weight = np.asarray(sample_weight) + sample_weight = xp.asarray(sample_weight, device=device_) check_consistent_length(y_true, y_pred, sample_weight) - n_labels = labels.size + n_labels = size(labels) # If labels are not consecutive integers starting from zero, then # y_true and y_pred must be converted into index form need_index_conversion = not ( - labels.dtype.kind in {"i", "u", "b"} - and np.all(labels == np.arange(n_labels)) - and y_true.min() >= 0 - and y_pred.min() >= 0 + xp.isdtype(labels.dtype, 
("signed integer", "unsigned integer", "bool")) + and xp.all(labels == xp.arange(n_labels, device=device_)) + and xp.min(y_true) >= 0 + and xp.min(y_pred) >= 0 ) if need_index_conversion: - label_to_ind = {y: x for x, y in enumerate(labels)} - y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred]) - y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true]) + # convert 0D array into scalar type, see https://github.com/data-apis/array-api-strict/issues/109: + if xp.isdtype(labels.dtype, ("real floating")): + scalar_dtype = float + else: + scalar_dtype = str + label_to_ind = {scalar_dtype(entry): idx for idx, entry in enumerate(labels)} + y_pred = xp.asarray( + [label_to_ind.get(scalar_dtype(x), n_labels + 1) for x in y_pred], + device=device_, + ) + y_true = xp.asarray( + [label_to_ind.get(scalar_dtype(x), n_labels + 1) for x in y_true], + device=device_, + ) # intersect y_pred, y_true with labels, eliminate items not in labels - ind = np.logical_and(y_pred < n_labels, y_true < n_labels) - if not np.all(ind): + ind = xp.logical_and(y_pred < n_labels, y_true < n_labels) + if not xp.all(ind): y_pred = y_pred[ind] y_true = y_true[ind] # also eliminate weights of eliminated items sample_weight = sample_weight[ind] # Choose the accumulator dtype to always have high precision - if sample_weight.dtype.kind in {"i", "u", "b"}: + if xp.isdtype(sample_weight.dtype, ("signed integer", "unsigned integer", "bool")): dtype = np.int64 else: dtype = np.float64 - cm = coo_matrix( - (sample_weight, (y_true, y_pred)), + ( + _convert_to_numpy(sample_weight, xp=xp), + (_convert_to_numpy(y_true, xp=xp), _convert_to_numpy(y_pred, xp=xp)), + ), shape=(n_labels, n_labels), dtype=dtype, ).toarray() + cm = xp.asarray(cm) with np.errstate(all="ignore"): if normalize == "true": @@ -528,7 +547,7 @@ def confusion_matrix( cm = cm / cm.sum(axis=0, keepdims=True) elif normalize == "all": cm = cm / cm.sum() - cm = np.nan_to_num(cm) + cm = _nan_to_num(cm) if cm.shape == (1, 
1): warnings.warn( diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index b66353e5ecfab..20e6a140b20e8 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -10,6 +10,7 @@ from scipy.stats import bernoulli from sklearn import datasets, svm +from sklearn.base import config_context from sklearn.datasets import make_multilabel_classification from sklearn.exceptions import UndefinedMetricWarning from sklearn.metrics import ( @@ -39,8 +40,10 @@ from sklearn.model_selection import cross_val_score from sklearn.preprocessing import LabelBinarizer, label_binarize from sklearn.tree import DecisionTreeClassifier +from sklearn.utils._array_api import yield_namespace_device_dtype_combinations from sklearn.utils._mocking import MockDataFrame from sklearn.utils._testing import ( + _array_api_for_tests, assert_allclose, assert_almost_equal, assert_array_almost_equal, @@ -1265,7 +1268,7 @@ def test_confusion_matrix_multiclass_subset_labels(): @pytest.mark.parametrize( "labels, err_msg", [ - ([], "'labels' should contains at least one label."), + ([], "'labels' should contain at least one label."), ([3, 4], "At least one label specified must be in y_true"), ], ids=["empty list", "unknown labels"], @@ -3395,3 +3398,19 @@ def test_d2_log_loss_score_raises(): err = "The labels array needs to contain at least two" with pytest.raises(ValueError, match=err): d2_log_loss_score(y_true, y_pred, labels=labels) + + +@pytest.mark.parametrize( + "array_namespace, device, _", yield_namespace_device_dtype_combinations() +) +def test_confusion_matrix_array_api(array_namespace, device, _): + """Test that `confusion_matrix` works for all array types if need_index_conversion + evaluates to `True`and that it raises if not at least one label from `y_pred` is in + `y_true`.""" + xp = _array_api_for_tests(array_namespace, device) + + y_true = xp.asarray([1, 2, 3], device=device) + y_pred = 
def _nan_to_num(array, xp=None):
    """Replace NaN with zero and infinities with the largest finite values.

    Follows the default behaviour of ``np.nan_to_num``: NaN becomes 0,
    positive infinity becomes the largest finite value of the dtype and
    negative infinity becomes the most negative finite value of the dtype.

    Parameters
    ----------
    array : array
        Input array. It is never modified in place by the fallback
        implementation below.

    xp : module, default=None
        Array namespace forwarded to `get_namespace`, which resolves the
        namespace actually used.

    Returns
    -------
    array : array
        Array with the same dtype as the input in which the non-finite
        entries have been replaced.

    Raises
    ------
    TypeError
        If the namespace provides no ``nan_to_num`` and the dtype is neither
        boolean, integral nor real floating (e.g. complex), which the
        fallback implementation does not handle.
    """
    xp, _ = get_namespace(array, xp=xp)

    # Prefer the native implementation. Checking for the attribute (rather
    # than catching AttributeError around the call) avoids silently falling
    # back when the native function itself raises an AttributeError.
    if hasattr(xp, "nan_to_num"):
        return xp.nan_to_num(array)

    # Fallback for namespaces without `nan_to_num` (e.g. array_api_strict).
    if xp.isdtype(array.dtype, ("bool", "integral")):
        # Boolean and integer arrays cannot hold NaN or inf: nothing to
        # replace. Return a copy so the result never aliases the input,
        # consistently with the floating point branch.
        return xp.asarray(array, copy=True)

    if not xp.isdtype(array.dtype, "real floating"):
        raise TypeError(
            f"_nan_to_num does not support dtype {array.dtype} for namespace "
            f"{xp.__name__}."
        )

    finfo = xp.finfo(array.dtype)
    is_inf = xp.isinf(array)
    # `xp.where` builds new arrays instead of assigning through boolean
    # masks, so the caller's array is left untouched and immutable array
    # types are supported as well.
    array = xp.where(xp.isnan(array), xp.zeros_like(array), array)
    array = xp.where(is_inf & (array > 0), xp.full_like(array, finfo.max), array)
    array = xp.where(is_inf & (array < 0), xp.full_like(array, finfo.min), array)
    return array