From 84309211ec246bf09f416d27fcf812685bd24d49 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 14 Dec 2019 19:40:04 +0800 Subject: [PATCH 1/4] MNT Raise erorr when normalize is invalid in confusion_matrix --- sklearn/metrics/_classification.py | 4 ++++ sklearn/metrics/tests/test_classification.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 322ac3409722f..6e6bc839dc95d 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -283,6 +283,10 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, check_consistent_length(y_true, y_pred, sample_weight) + if normalize not in {'true', 'pred', 'all', None}: + raise ValueError("normalize must be one of {'true', 'pred', " + "'all', None}") + n_labels = labels.size label_to_ind = {y: x for x, y in enumerate(labels)} # convert yt, yp into index diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 4c1db4b55bb16..c33c3a829cc16 100644 --- a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -526,6 +526,13 @@ def test_confusion_matrix_normalize(normalize, cm_dtype, expected_results): assert cm.dtype.kind == cm_dtype +def test_confusion_matrix_normalize_wrong_option(): + y_test = [0, 0, 0, 0, 1, 1, 1, 1] + y_pred = [0, 0, 0, 0, 0, 0, 0, 0] + with pytest.raises(ValueError, match='normalize must be one of'): + confusion_matrix(y_test, y_pred, normalize=True) + + def test_confusion_matrix_normalize_single_class(): y_test = [0, 0, 0, 0, 1, 1, 1, 1] y_pred = [0, 0, 0, 0, 0, 0, 0, 0] From 465fa64ded457ae3eca641ae9086c119d8f35863 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 14 Dec 2019 19:41:32 +0800 Subject: [PATCH 2/4] remove --- sklearn/metrics/_plot/confusion_matrix.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index be59c8dd9a847..9eec258ef69ce 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -184,10 +184,6 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None, if not is_classifier(estimator): raise ValueError("plot_confusion_matrix only supports classifiers") - if normalize not in {'true', 'pred', 'all', None}: - raise ValueError("normalize must be one of {'true', 'pred', " - "'all', None}") - y_pred = estimator.predict(X) cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight, labels=labels, normalize=normalize) From 39eedac9764ba50a0cc9d348acef444bc2304e14 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sat, 14 Dec 2019 21:40:19 +0800 Subject: [PATCH 3/4] Update sklearn/metrics/_classification.py Co-Authored-By: Roman Yurchak --- sklearn/metrics/_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 6e6bc839dc95d..343e63b6c0ae9 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -283,7 +283,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, check_consistent_length(y_true, y_pred, sample_weight) - if normalize not in {'true', 'pred', 'all', None}: + if normalize not in ['true', 'pred', 'all', None]: raise ValueError("normalize must be one of {'true', 'pred', " "'all', None}") From d531f7347143d650e8a720719f2f6259b0200978 Mon Sep 17 00:00:00 2001 From: Hanmin Qin Date: Sun, 15 Dec 2019 10:34:30 +0800 Subject: [PATCH 4/4] whats new --- doc/whats_new/v0.22.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index af08b832e9f6f..19a8327783b20 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -15,6 +15,13 @@ This is a bug-fix release to primarily resolve some packaging issues in version Changelog --------- +:mod:`sklearn.metrics` +...................... + +- |Fix| :func:`metrics.plot_confusion_matrix` now raises error when `normalize` + is invalid. Previously, it runs fine with no normalization. + :pr:`15888` by `Hanmin Qin`_. + :mod:`sklearn.utils` ....................