diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index af08b832e9f6f..19a8327783b20 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -15,6 +15,13 @@ This is a bug-fix release to primarily resolve some packaging issues in version Changelog --------- +:mod:`sklearn.metrics` +...................... + +- |Fix| :func:`metrics.plot_confusion_matrix` now raises an error when `normalize` + is invalid. Previously, it ran fine with no normalization. + :pr:`15888` by `Hanmin Qin`_. + :mod:`sklearn.utils` .................... diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py index 322ac3409722f..343e63b6c0ae9 100644 --- a/sklearn/metrics/_classification.py +++ b/sklearn/metrics/_classification.py @@ -283,6 +283,10 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None, check_consistent_length(y_true, y_pred, sample_weight) + if normalize not in ['true', 'pred', 'all', None]: + raise ValueError("normalize must be one of {'true', 'pred', " + "'all', None}") + n_labels = labels.size label_to_ind = {y: x for x, y in enumerate(labels)} # convert yt, yp into index diff --git a/sklearn/metrics/_plot/confusion_matrix.py b/sklearn/metrics/_plot/confusion_matrix.py index be59c8dd9a847..9eec258ef69ce 100644 --- a/sklearn/metrics/_plot/confusion_matrix.py +++ b/sklearn/metrics/_plot/confusion_matrix.py @@ -184,10 +184,6 @@ def plot_confusion_matrix(estimator, X, y_true, labels=None, if not is_classifier(estimator): raise ValueError("plot_confusion_matrix only supports classifiers") - if normalize not in {'true', 'pred', 'all', None}: - raise ValueError("normalize must be one of {'true', 'pred', " - "'all', None}") - y_pred = estimator.predict(X) cm = confusion_matrix(y_true, y_pred, sample_weight=sample_weight, labels=labels, normalize=normalize) diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py index 4c1db4b55bb16..c33c3a829cc16 100644 ---
a/sklearn/metrics/tests/test_classification.py +++ b/sklearn/metrics/tests/test_classification.py @@ -526,6 +526,13 @@ def test_confusion_matrix_normalize(normalize, cm_dtype, expected_results): assert cm.dtype.kind == cm_dtype +def test_confusion_matrix_normalize_wrong_option(): + y_test = [0, 0, 0, 0, 1, 1, 1, 1] + y_pred = [0, 0, 0, 0, 0, 0, 0, 0] + with pytest.raises(ValueError, match='normalize must be one of'): + confusion_matrix(y_test, y_pred, normalize=True) + + def test_confusion_matrix_normalize_single_class(): y_test = [0, 0, 0, 0, 1, 1, 1, 1] y_pred = [0, 0, 0, 0, 0, 0, 0, 0]