From 15b8b981b7b3824b9836ddbcd5ca5d7ca686f952 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 15 Jul 2023 01:18:59 +0200 Subject: [PATCH 01/10] ENH add pos_label to confusion_matrix --- doc/whats_new/v1.4.rst | 7 +++++++ sklearn/metrics/_classification.py | 11 ++++++++++- sklearn/metrics/tests/test_classification.py | 18 ++++++++++++------ 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c2b7d19404af9..c58abd816b209 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -111,6 +111,13 @@ Changelog to :ref:`metadata routing user guide `. :pr:`26789` by `Adrin Jalali`_. +:mod:`sklearn.metrics` +...................... + +- |Enhancement| add a `pos_label` to :func:`metrics.confusion_matrix` such that + we always report the TP/TN/FP/FN in the same position in the matrix. + :pr:`xxx` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.model_selection` .............................. diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 971ea5a25ffe3..19c1e331a2653 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -226,13 +226,14 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): "y_true": ["array-like"], "y_pred": ["array-like"], "labels": ["array-like", None], + "pos_label": [Real, str, "boolean", None], "sample_weight": ["array-like", None], "normalize": [StrOptions({"true", "pred", "all"}), None], }, prefer_skip_nested_validation=True, ) def confusion_matrix( - y_true, y_pred, *, labels=None, sample_weight=None, normalize=None + y_true, y_pred, *, labels=None, pos_label=1, sample_weight=None, normalize=None ): """Compute confusion matrix to evaluate the accuracy of a classification. @@ -260,6 +261,11 @@ def confusion_matrix( If ``None`` is given, those that appear at least once in ``y_true`` or ``y_pred`` are used in sorted order. + pos_label : int, float, bool or str, default=1 + Only taken into account when the data is binary. Ignore otherwise. + + .. versionadded:: 1.4 + sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -382,6 +388,9 @@ def confusion_matrix( cm = cm / cm.sum() cm = np.nan_to_num(cm) + if n_labels == 2 and pos_label != labels[-1]: + cm = cm[::-1, ::-1] + return cm diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index a05a532ecb3f2..0a7b816171c3a 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -453,13 +453,19 @@ def test_precision_recall_f_unused_pos_label(): ) -def test_confusion_matrix_binary(): +@pytest.mark.parametrize("pos_label", [0, 1]) +def test_confusion_matrix_binary(pos_label): # Test confusion matrix - binary classification case y_true, y_pred, _ = make_prediction(binary=True) - def test(y_true, y_pred): - cm = confusion_matrix(y_true, y_pred) - assert_array_equal(cm, [[22, 3], [8, 17]]) + def test(y_true, y_pred, pos_label): + cm = confusion_matrix(y_true, y_pred, pos_label=pos_label) + expected_cm = np.array([[22, 3], [8, 17]]) + if pos_label in {"0", 0}: + # we should flip the confusion matrix to respect the documentation + # of tp, fp, fn, tn + expected_cm = expected_cm[::-1, ::-1] + assert_array_equal(cm, expected_cm) tp, fp, fn, tn = cm.flatten() num = tp * tn - fp * fn @@ -470,8 +476,8 @@ def test(y_true, y_pred): assert_array_almost_equal(mcc, true_mcc, decimal=2) assert_array_almost_equal(mcc, 0.57, decimal=2) - test(y_true, y_pred) - test([str(y) for y in y_true], [str(y) for y in y_pred]) + test(y_true, y_pred, pos_label) + test([str(y) for y in y_true], [str(y) for y in y_pred], str(pos_label)) def test_multilabel_confusion_matrix_binary(): From 1ad6db8830f40ecb065e2991a02fc065a471734d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 15 Jul 2023 01:22:18 +0200 Subject: [PATCH 02/10] update changelog --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c58abd816b209..a6d34c541ce9e 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -116,7 +116,7 @@ Changelog - |Enhancement| add a `pos_label` to :func:`metrics.confusion_matrix` such that we always report the TP/TN/FP/FN in the same position in the matrix. - :pr:`xxx` by :user:`Guillaume Lemaitre `. + :pr:`26839` by :user:`Guillaume Lemaitre `. :mod:`sklearn.model_selection` .............................. From 6a05658914270c8c09ae807c1ee26fa741bcc890 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 15 Jul 2023 11:59:53 +0200 Subject: [PATCH 03/10] Update sklearn/metrics/_classification.py Co-authored-by: Adrin Jalali --- sklearn/metrics/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 19c1e331a2653..4ba7e569765d3 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -262,7 +262,7 @@ def confusion_matrix( in ``y_true`` or ``y_pred`` are used in sorted order. pos_label : int, float, bool or str, default=1 - Only taken into account when the data is binary. Ignore otherwise. + Only taken into account when the data is binary. Ignored otherwise. .. versionadded:: 1.4 From 72767e7e50101701716f7f56ff821ae11c9bbd86 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 15 Jul 2023 22:57:39 +0200 Subject: [PATCH 04/10] iter --- doc/whats_new/v1.4.rst | 7 +++-- sklearn/metrics/_classification.py | 50 +++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index a6d34c541ce9e..c42cc706e6646 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -114,8 +114,11 @@ Changelog :mod:`sklearn.metrics` ...................... -- |Enhancement| add a `pos_label` to :func:`metrics.confusion_matrix` such that - we always report the TP/TN/FP/FN in the same position in the matrix. +- |Enhancement| |Fix| add a `pos_label` to :func:`metrics.confusion_matrix` + avoiding ambiguity regarding the position of the positive class label in the + matrix. An error is raised if the positive label cannot be set to `1`. An + error is also raised if the `pos_label` is set on other classification + problem than binary. :pr:`26839` by :user:`Guillaume Lemaitre `. :mod:`sklearn.model_selection` diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 19c1e331a2653..e9122cd024782 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -233,7 +233,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None): prefer_skip_nested_validation=True, ) def confusion_matrix( - y_true, y_pred, *, labels=None, pos_label=1, sample_weight=None, normalize=None + y_true, y_pred, *, labels=None, pos_label=None, sample_weight=None, normalize=None ): """Compute confusion matrix to evaluate the accuracy of a classification. @@ -261,8 +261,12 @@ def confusion_matrix( If ``None`` is given, those that appear at least once in ``y_true`` or ``y_pred`` are used in sorted order. - pos_label : int, float, bool or str, default=1 - Only taken into account when the data is binary. Ignore otherwise. + pos_label : int, float, bool or str, default=None + The label of the positive class for binary classification. + When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, + `pos_label` is set to 1, otherwise an error will be raised. + An error is raised if `pos_label` is set and `y_true` is not a binary + classification problem. .. versionadded:: 1.4 @@ -322,10 +326,23 @@ def confusion_matrix( >>> (tn, fp, fn, tp) (0, 2, 1, 1) """ + if len(y_true) == 0 and len(y_pred) == 0: + # early return for empty arrays avoiding all checks + n_classes = 0 if labels is None else len(labels) + return np.zeros((n_classes, n_classes), dtype=int) + y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in ("binary", "multiclass"): raise ValueError("%s is not supported" % y_type) + if y_type == "binary": + pos_label = _check_pos_label_consistency(pos_label, y_true) + elif pos_label is not None: + raise ValueError( + "`pos_label` should only be set when the target is binary. Got " + f"{y_type} type of target instead." + ) + if labels is None: labels = unique_labels(y_true, y_pred) else: @@ -388,7 +405,7 @@ def confusion_matrix( cm = cm / cm.sum() cm = np.nan_to_num(cm) - if n_labels == 2 and pos_label != labels[-1]: + if pos_label is not None and pos_label != labels[-1]: cm = cm[::-1, ::-1] return cm @@ -897,7 +914,7 @@ def jaccard_score( }, prefer_skip_nested_validation=True, ) -def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): +def matthews_corrcoef(y_true, y_pred, *, pos_label=None, sample_weight=None): """Compute the Matthews correlation coefficient (MCC). The Matthews correlation coefficient is used in machine learning as a @@ -965,12 +982,21 @@ def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): if y_type not in {"binary", "multiclass"}: raise ValueError("%s is not supported" % y_type) + if y_type == "binary": + # we can set `pos_label` to any class labels because the computation of MCC + # is symmetric and invariant to `pos_label` switch. + pos_label = y_true[0] + else: + pos_label = None + lb = LabelEncoder() lb.fit(np.hstack([y_true, y_pred])) y_true = lb.transform(y_true) y_pred = lb.transform(y_pred) - C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight) + C = confusion_matrix( + y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight + ) t_sum = C.sum(axis=1, dtype=np.float64) p_sum = C.sum(axis=0, dtype=np.float64) n_correct = np.trace(C, dtype=np.float64) @@ -2389,7 +2415,17 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=Fals >>> balanced_accuracy_score(y_true, y_pred) 0.625 """ - C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight) + y_type, y_true, y_pred = _check_targets(y_true, y_pred) + if y_type == "binary": + # We can set `pos_label` to any values since we are computing per-class + # statistics and average them. + pos_label = y_true[0] + else: + pos_label = None + + C = confusion_matrix( + y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight + ) with np.errstate(divide="ignore", invalid="ignore"): per_class = np.diag(C) / C.sum(axis=1) if np.any(np.isnan(per_class)): From 47b0f338c585d12b224a00626e541ab08ec7f7b3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 15 Jul 2023 23:00:05 +0200 Subject: [PATCH 05/10] iter --- sklearn/metrics/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index e9122cd024782..5fc6176462fa4 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -914,7 +914,7 @@ def jaccard_score( }, prefer_skip_nested_validation=True, ) -def matthews_corrcoef(y_true, y_pred, *, pos_label=None, sample_weight=None): +def matthews_corrcoef(y_true, y_pred, *, sample_weight=None): """Compute the Matthews correlation coefficient (MCC). The Matthews correlation coefficient is used in machine learning as a From 6ca1b39685e91d7a834c30b119c18c18bd4e505c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sat, 15 Jul 2023 23:33:24 +0200 Subject: [PATCH 06/10] iter --- doc/whats_new/v1.4.rst | 18 ++++++++++-------- sklearn/metrics/_plot/confusion_matrix.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c42cc706e6646..e595b25bcc3be 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -104,13 +104,6 @@ Changelog ``**score_params`` which are passed to the underlying scorer. :pr:`26525` by :user:`Omar Salman `. -:mod:`sklearn.pipeline` -....................... - -- |Feature| :class:`pipeline.Pipeline` now supports metadata routing according - to :ref:`metadata routing user guide `. :pr:`26789` by - `Adrin Jalali`_. - :mod:`sklearn.metrics` ...................... @@ -118,7 +111,9 @@ Changelog avoiding ambiguity regarding the position of the positive class label in the matrix. An error is raised if the positive label cannot be set to `1`. An error is also raised if the `pos_label` is set on other classification - problem than binary. + problem than binary. `pos_label` is also added to + :meth:`metrics.ConfusionMatrixDisplay.from_estimator` and + :meth:`metrics.ConfusionMatrixDisplay.from_predictions`. :pr:`26839` by :user:`Guillaume Lemaitre `. :mod:`sklearn.model_selection` @@ -130,6 +125,13 @@ Changelog object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin Jalali`_. +:mod:`sklearn.pipeline` +....................... + +- |Feature| :class:`pipeline.Pipeline` now supports metadata routing according + to :ref:`metadata routing user guide `. :pr:`26789` by + `Adrin Jalali`_. + :mod:`sklearn.tree` ................... diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index f0bda0dc73d39..a4597d5871338 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -202,6 +202,7 @@ def from_estimator( y, *, labels=None, + pos_label=None, sample_weight=None, normalize=None, display_labels=None, @@ -238,6 +239,15 @@ def from_estimator( that appear at least once in `y_true` or `y_pred` are used in sorted order. + pos_label : int, float, bool or str, default=None + The label of the positive class for binary classification. + When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, + `pos_label` is set to 1, otherwise an error will be raised. + An error is raised if `pos_label` is set and `y_true` is not a binary + classification problem. + + .. versionadded:: 1.4 + sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -323,6 +333,7 @@ def from_estimator( y, y_pred, sample_weight=sample_weight, + pos_label=pos_label, labels=labels, normalize=normalize, display_labels=display_labels, @@ -343,6 +354,7 @@ def from_predictions( y_pred, *, labels=None, + pos_label=None, sample_weight=None, normalize=None, display_labels=None, @@ -376,6 +388,15 @@ def from_predictions( that appear at least once in `y_true` or `y_pred` are used in sorted order. + pos_label : int, float, bool or str, default=None + The label of the positive class for binary classification. + When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, + `pos_label` is set to 1, otherwise an error will be raised. + An error is raised if `pos_label` is set and `y_true` is not a binary + classification problem. + + .. versionadded:: 1.4 + sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -465,6 +486,7 @@ def from_predictions( y_pred, sample_weight=sample_weight, labels=labels, + pos_label=pos_label, normalize=normalize, ) From 93b793678a8f8b3097919458f92a709c8d53c622 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 16 Jul 2023 12:00:09 +0200 Subject: [PATCH 07/10] iter --- sklearn/metrics/_classification.py | 29 ++++++++++++++++++++++------ sklearn/metrics/tests/test_common.py | 2 ++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 5fc6176462fa4..7839d40ce4f64 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -326,15 +326,15 @@ def confusion_matrix( >>> (tn, fp, fn, tp) (0, 2, 1, 1) """ - if len(y_true) == 0 and len(y_pred) == 0: - # early return for empty arrays avoiding all checks - n_classes = 0 if labels is None else len(labels) - return np.zeros((n_classes, n_classes), dtype=int) - y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type not in ("binary", "multiclass"): raise ValueError("%s is not supported" % y_type) + if y_true.size == 0 and y_pred.size == 0: + # early return for empty arrays avoiding all checks + n_classes = 0 if labels is None else len(labels) + return np.zeros((n_classes, n_classes), dtype=int) + if y_type == "binary": pos_label = _check_pos_label_consistency(pos_label, y_true) elif pos_label is not None: @@ -696,7 +696,17 @@ class labels [2]_. .. [3] `Wikipedia entry for the Cohen's kappa `_. """ - confusion = confusion_matrix(y1, y2, labels=labels, sample_weight=sample_weight) + y_type, y1, y2 = _check_targets(y1, y2) + if y_type == "binary": + # we can set `pos_label` to any class labels because the computation of MCC + # is symmetric and invariant to `pos_label` switch. + pos_label = y1[0] + else: + pos_label = None + + confusion = confusion_matrix( + y1, y2, labels=labels, pos_label=pos_label, sample_weight=sample_weight + ) n_classes = confusion.shape[0] sum0 = np.sum(confusion, axis=0) sum1 = np.sum(confusion, axis=1) @@ -1942,11 +1952,18 @@ class after being classified as negative. This is the case when the f"problems, got targets of type: {y_type}" ) + if labels is None: + classes = np.unique(y_true) + pos_label = 1 if len(classes) < 2 else classes[1] + else: + pos_label = labels[-1] + cm = confusion_matrix( y_true, y_pred, sample_weight=sample_weight, labels=labels, + pos_label=pos_label, ) # Case when `y_test` contains a single class and `y_test == y_pred`. diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 6b132ccd2c37a..b0450d1d06127 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -318,6 +318,8 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # Metrics with a "pos_label" argument METRICS_WITH_POS_LABEL = { + "unnormalized_confusion_matrix", + "normalized_confusion_matrix", "roc_curve", "precision_recall_curve", "det_curve", From 7958dd5751806102215610bd7c1c669ef128e225 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 16 Jul 2023 13:53:31 +0200 Subject: [PATCH 08/10] improve coverage --- sklearn/metrics/tests/test_classification.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 0a7b816171c3a..136040222f346 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -10,7 +10,7 @@ from scipy.stats import bernoulli from sklearn import datasets, svm -from sklearn.datasets import make_multilabel_classification +from sklearn.datasets import make_classification, make_multilabel_classification from sklearn.exceptions import UndefinedMetricWarning from sklearn.metrics import ( accuracy_score, @@ -453,6 +453,13 @@ def test_precision_recall_f_unused_pos_label(): ) +def test_confusion_matrix_pos_label_error(): + _, y = make_classification(n_classes=3, n_clusters_per_class=1, random_state=0) + err_msg = "`pos_label` should only be set when the target is binary." + with pytest.raises(ValueError, match=err_msg): + confusion_matrix(y, y, pos_label=1) + + @pytest.mark.parametrize("pos_label", [0, 1]) def test_confusion_matrix_binary(pos_label): # Test confusion matrix - binary classification case From d39ef7a6aebcf3ae930d2a8f6202fb00d71d22ea Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 24 Jul 2023 12:03:23 +0200 Subject: [PATCH 09/10] Apply suggestions from code review Co-authored-by: Omar Salman --- doc/whats_new/v1.4.rst | 8 ++++---- sklearn/metrics/_classification.py | 6 +++--- sklearn/metrics/_plot/confusion_matrix.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index e595b25bcc3be..f8a695f481f9e 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -107,11 +107,11 @@ Changelog :mod:`sklearn.metrics` ...................... -- |Enhancement| |Fix| add a `pos_label` to :func:`metrics.confusion_matrix` +- |Enhancement| |Fix| Added a `pos_label` to :func:`metrics.confusion_matrix` avoiding ambiguity regarding the position of the positive class label in the - matrix. An error is raised if the positive label cannot be set to `1`. An - error is also raised if the `pos_label` is set on other classification - problem than binary. `pos_label` is also added to + matrix. An error is raised if the positive label cannot be set to `1` and + also if the `pos_label` is set on other classification + problems than binary. `pos_label` is also added to :meth:`metrics.ConfusionMatrixDisplay.from_estimator` and :meth:`metrics.ConfusionMatrixDisplay.from_predictions`. :pr:`26839` by :user:`Guillaume Lemaitre `. diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 7839d40ce4f64..9e4559e336b20 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -265,7 +265,7 @@ def confusion_matrix( The label of the positive class for binary classification. When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, `pos_label` is set to 1, otherwise an error will be raised. - An error is raised if `pos_label` is set and `y_true` is not a binary + An error is also raised if `pos_label` is set and `y_true` is not a binary classification problem. .. versionadded:: 1.4 @@ -2434,8 +2434,8 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=Fals """ y_type, y_true, y_pred = _check_targets(y_true, y_pred) if y_type == "binary": - # We can set `pos_label` to any values since we are computing per-class - # statistics and average them. + # We can set `pos_label` to any value since we are computing per-class + # statistics and averaging them. pos_label = y_true[0] else: pos_label = None diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index a4597d5871338..6c72ba2511a39 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -243,7 +243,7 @@ def from_estimator( The label of the positive class for binary classification. When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, `pos_label` is set to 1, otherwise an error will be raised. - An error is raised if `pos_label` is set and `y_true` is not a binary + An error is also raised if `pos_label` is set and `y_true` is not a binary classification problem. .. versionadded:: 1.4 @@ -392,7 +392,7 @@ def from_predictions( The label of the positive class for binary classification. When `pos_label=None`, if `y_true` is in `{-1, 1}` or `{0, 1}`, `pos_label` is set to 1, otherwise an error will be raised. - An error is raised if `pos_label` is set and `y_true` is not a binary + An error is also raised if `pos_label` is set and `y_true` is not a binary classification problem. .. versionadded:: 1.4 From 3d6b3888a906975059b4b5418d4ebd1ce38c7392 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 24 Jul 2023 12:04:26 +0200 Subject: [PATCH 10/10] Update sklearn/metrics/_classification.py --- sklearn/metrics/_classification.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 9e4559e336b20..7ea787deda5ae 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -406,6 +406,8 @@ def confusion_matrix( cm = np.nan_to_num(cm) if pos_label is not None and pos_label != labels[-1]: + # Reorder the confusion matrix such that TP is at index + # [1, 1]. cm = cm[::-1, ::-1] return cm