diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index 33c14f4eb0316..5fd2c1f639e06 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -377,6 +377,15 @@ Changelog
   :func:`sklearn.metrics.zero_one_loss` now support Array API compatible inputs.
   :pr:`27137` by :user:`Edoardo Abati `.
 
+- |Enhancement| |Fix| Added a `pos_label` parameter to
+  :func:`metrics.confusion_matrix` to avoid ambiguity regarding the position of
+  the positive class label in the matrix. An error is raised if the positive
+  label cannot be inferred to be `1`, and also if `pos_label` is set on
+  classification problems other than binary. `pos_label` is also added to
+  :meth:`metrics.ConfusionMatrixDisplay.from_estimator` and
+  :meth:`metrics.ConfusionMatrixDisplay.from_predictions`.
+  :pr:`26839` by :user:`Guillaume Lemaitre `.
+
 - |API| Deprecated `needs_threshold` and `needs_proba` from :func:`metrics.make_scorer`.
   These parameters will be removed in version 1.6. Instead, use `response_method` that
   accepts `"predict"`, `"predict_proba"` or `"decision_function"` or a list of such
diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index cd485c581bc54..fb39b9388a41f 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -226,13 +226,14 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
         "y_true": ["array-like"],
         "y_pred": ["array-like"],
         "labels": ["array-like", None],
+        "pos_label": [Real, str, "boolean", None],
         "sample_weight": ["array-like", None],
         "normalize": [StrOptions({"true", "pred", "all"}), None],
     },
     prefer_skip_nested_validation=True,
 )
 def confusion_matrix(
-    y_true, y_pred, *, labels=None, sample_weight=None, normalize=None
+    y_true, y_pred, *, labels=None, pos_label=None, sample_weight=None, normalize=None
 ):
     """Compute confusion matrix to evaluate the accuracy of a classification.
 
@@ -260,6 +261,15 @@ def confusion_matrix(
         If ``None`` is given, those that appear at least once in
         ``y_true`` or ``y_pred`` are used in sorted order.
 
+    pos_label : int, float, bool or str, default=None
+        The label of the positive class for binary classification.
+        When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`,
+        `pos_label` is set to 1, otherwise an error will be raised.
+        An error is also raised if `pos_label` is set and `y_true` is not a binary
+        classification problem.
+
+        .. versionadded:: 1.4
+
     sample_weight : array-like of shape (n_samples,), default=None
         Sample weights.
 
@@ -320,6 +330,19 @@ def confusion_matrix(
     if y_type not in ("binary", "multiclass"):
         raise ValueError("%s is not supported" % y_type)
 
+    if y_true.size == 0 and y_pred.size == 0:
+        # early return for empty arrays, avoiding all further checks
+        n_classes = 0 if labels is None else len(labels)
+        return np.zeros((n_classes, n_classes), dtype=int)
+
+    if y_type == "binary":
+        pos_label = _check_pos_label_consistency(pos_label, y_true)
+    elif pos_label is not None:
+        raise ValueError(
+            "`pos_label` should only be set when the target is binary. Got "
+            f"{y_type} type of target instead."
+        )
+
     if labels is None:
         labels = unique_labels(y_true, y_pred)
     else:
@@ -382,6 +405,11 @@ def confusion_matrix(
         cm = cm / cm.sum()
     cm = np.nan_to_num(cm)
 
+    if pos_label is not None and pos_label != labels[-1]:
+        # Reorder the confusion matrix such that the positive class comes last,
+        # i.e. TP is at index [1, 1].
+        cm = cm[::-1, ::-1]
+
     if cm.shape == (1, 1):
         warnings.warn(
             (
@@ -680,7 +708,17 @@ class labels [2]_.
     .. [3] `Wikipedia entry for the Cohen's kappa
            <https://en.wikipedia.org/wiki/Cohen%27s_kappa>`_.
""" - confusion = confusion_matrix(y1, y2, labels=labels, sample_weight=sample_weight) + y_type, y1, y2 = _check_targets(y1, y2) + if y_type == "binary": + # we can set `pos_label` to any class labels because the computation of MCC + # is symmetric and invariant to `pos_label` switch. + pos_label = y1[0] + else: + pos_label = None + + confusion = confusion_matrix( + y1, y2, labels=labels, pos_label=pos_label, sample_weight=sample_weight + ) n_classes = confusion.shape[0] sum0 = np.sum(confusion, axis=0) sum1 = np.sum(confusion, axis=1) @@ -966,12 +1004,21 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): if y_type not in {"binary", "multiclass"}: raise ValueError("%s is not supported" % y_type) + if y_type == "binary": + # we can set `pos_label` to any class labels because the computation of MCC + # is symmetric and invariant to `pos_label` switch. + pos_label = y_true[0] + else: + pos_label = None + lb = LabelEncoder() lb.fit(np.hstack([y_true, y_pred])) y_true = lb.transform(y_true) y_pred = lb.transform(y_pred) - C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight) + C = confusion_matrix( + y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight + ) t_sum = C.sum(axis=1, dtype=np.float64) p_sum = C.sum(axis=0, dtype=np.float64) n_correct = np.trace(C, dtype=np.float64) @@ -1921,11 +1968,18 @@ class after being classified as negative. This is the case when the f"problems, got targets of type: {y_type}" ) + if labels is None: + classes = np.unique(y_true) + pos_label = 1 if len(classes) < 2 else classes[1] + else: + pos_label = labels[-1] + cm = confusion_matrix( y_true, y_pred, sample_weight=sample_weight, labels=labels, + pos_label=pos_label, ) # Case when `y_test` contains a single class and `y_test == y_pred`. @@ -2396,7 +2450,17 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=Fals >>> balanced_accuracy_score(y_true, y_pred) 0.625 """ - C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight) + y_type, y_true, y_pred = _check_targets(y_true, y_pred) + if y_type == "binary": + # We can set `pos_label` to any value since we are computing per-class + # statistics and averaging them. + pos_label = y_true[0] + else: + pos_label = None + + C = confusion_matrix( + y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight + ) with np.errstate(divide="ignore", invalid="ignore"): per_class = np.diag(C) / C.sum(axis=1) if np.any(np.isnan(per_class)): diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index f0bda0dc73d39..6c72ba2511a39 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -202,6 +202,7 @@ def from_estimator( y, *, labels=None, + pos_label=None, sample_weight=None, normalize=None, display_labels=None, @@ -238,6 +239,15 @@ def from_estimator( that appear at least once in `y_true` or `y_pred` are used in sorted order. + pos_label : int, float, bool or str, default=None + The label of the positive class for binary classification. + When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, + `pos_label` is set to 1, otherwise an error will be raised. + An error is also raised if `pos_label` is set and `y_true` is not a binary + classification problem. + + .. versionadded:: 1.4 + sample_weight : array-like of shape (n_samples,), default=None Sample weights. 
@@ -323,6 +333,7 @@ def from_estimator(
             y,
             y_pred,
             sample_weight=sample_weight,
+            pos_label=pos_label,
             labels=labels,
             normalize=normalize,
             display_labels=display_labels,
@@ -343,6 +354,7 @@ def from_predictions(
         y_pred,
         *,
         labels=None,
+        pos_label=None,
         sample_weight=None,
         normalize=None,
         display_labels=None,
@@ -376,6 +388,15 @@ def from_predictions(
             that appear at least once in `y_true` or `y_pred` are used in
             sorted order.
 
+        pos_label : int, float, bool or str, default=None
+            The label of the positive class for binary classification.
+            When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`,
+            `pos_label` is set to 1, otherwise an error will be raised.
+            An error is also raised if `pos_label` is set and `y_true` is not a binary
+            classification problem.
+
+            .. versionadded:: 1.4
+
         sample_weight : array-like of shape (n_samples,), default=None
             Sample weights.
 
@@ -465,6 +486,7 @@ def from_predictions(
             y_pred,
             sample_weight=sample_weight,
             labels=labels,
+            pos_label=pos_label,
             normalize=normalize,
         )
 
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 895e10ca851a6..84208ab1f3006 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -10,7 +10,7 @@
 from scipy.stats import bernoulli
 
 from sklearn import datasets, svm
-from sklearn.datasets import make_multilabel_classification
+from sklearn.datasets import make_classification, make_multilabel_classification
 from sklearn.exceptions import UndefinedMetricWarning
 from sklearn.metrics import (
     accuracy_score,
@@ -457,13 +457,26 @@ def test_precision_recall_f_unused_pos_label():
     )
 
 
-def test_confusion_matrix_binary():
+def test_confusion_matrix_pos_label_error():
+    _, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=0)
+    err_msg = "`pos_label` should only be set when the target is binary."
+    with pytest.raises(ValueError, match=err_msg):
+        confusion_matrix(y, y, pos_label=1)
+
+
+@pytest.mark.parametrize("pos_label", [0, 1])
+def test_confusion_matrix_binary(pos_label):
     # Test confusion matrix - binary classification case
     y_true, y_pred, _ = make_prediction(binary=True)
 
-    def test(y_true, y_pred):
-        cm = confusion_matrix(y_true, y_pred)
-        assert_array_equal(cm, [[22, 3], [8, 17]])
+    def test(y_true, y_pred, pos_label):
+        cm = confusion_matrix(y_true, y_pred, pos_label=pos_label)
+        expected_cm = np.array([[22, 3], [8, 17]])
+        if pos_label in {"0", 0}:
+            # with `pos_label=0`, the positive class is moved to the last
+            # row/column, so the expected confusion matrix is flipped
+            expected_cm = expected_cm[::-1, ::-1]
+        assert_array_equal(cm, expected_cm)
 
         tp, fp, fn, tn = cm.flatten()
         num = tp * tn - fp * fn
@@ -474,8 +487,8 @@ def test(y_true, y_pred):
         assert_array_almost_equal(mcc, true_mcc, decimal=2)
         assert_array_almost_equal(mcc, 0.57, decimal=2)
 
-    test(y_true, y_pred)
-    test([str(y) for y in y_true], [str(y) for y in y_pred])
+    test(y_true, y_pred, pos_label)
+    test([str(y) for y in y_true], [str(y) for y in y_pred], str(pos_label))
 
 
 def test_multilabel_confusion_matrix_binary():
diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index af652d1c90b41..c7092cec02244 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -324,6 +324,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
 
 # Metrics with a "pos_label" argument
 METRICS_WITH_POS_LABEL = {
+    "unnormalized_confusion_matrix",
+    "normalized_confusion_matrix",
     "roc_curve",
     "precision_recall_curve",
     "det_curve",
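
The snippet below is a minimal usage sketch of the behaviour this patch introduces, not part of any released scikit-learn API: it assumes a build with the patched `confusion_matrix` accepting the new `pos_label` parameter. With labels outside `{0, 1}`/`{-1, 1}` the positive class must be passed explicitly, and the matrix is then ordered so that the positive class occupies the last row and column.

# Sketch only: requires the patched confusion_matrix with `pos_label`
# (PR 26839 above); the parameter is not in released scikit-learn versions.
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = np.array(["ham", "spam", "spam", "ham", "spam"])
y_pred = np.array(["ham", "spam", "ham", "ham", "spam"])

# String labels cannot be mapped to a positive class automatically, so
# `pos_label` must be given; the positive class ends up in the last row/column.
cm = confusion_matrix(y_true, y_pred, pos_label="spam")
tn, fp, fn, tp = cm.ravel()  # true positives sit at index [1, 1]

# Choosing the other class as positive reverses both axes of the matrix.
cm_other = confusion_matrix(y_true, y_pred, pos_label="ham")
assert (cm_other == cm[::-1, ::-1]).all()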