From 26364231a4919b4f221cf4c3b7971b4bbfdd42f5 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 30 Oct 2019 15:09:30 +0100 Subject: [PATCH 1/8] Improve error message with implicit or invalid pos_label in plot_roc_curve --- sklearn/metrics/_plot/roc_curve.py | 11 ++++++++++ .../_plot/tests/test_plot_roc_curve.py | 21 +++++++++++++++++++ sklearn/metrics/_ranking.py | 5 ++++- 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 35e6e9512f105..06173da4a90c1 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -185,6 +185,17 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, if prediction_method is None: raise ValueError('response methods not defined') + if (pos_label is not None + and hasattr(estimator, "classes_") + and pos_label not in estimator.classes_): + estimator_name = estimator.__class__.__name__ + expected_classes = set(estimator.classes_) + raise ValueError("pos_label={} is not a valid class label for {}. " + "Expected one of {}." + .format(repr(pos_label), + estimator_name, + expected_classes)) + y_pred = prediction_method(X) if y_pred.ndim != 1: diff --git a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py index 1e928f55d8e73..481d5d07ebac4 100644 --- a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py +++ b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py @@ -93,3 +93,24 @@ def test_plot_roc_curve(pyplot, response_method, data_binary, assert viz.line_.get_label() == expected_label assert viz.ax_.get_ylabel() == "True Positive Rate" assert viz.ax_.get_xlabel() == "False Positive Rate" + + +def test_invalid_pos_label(data_binary): + X, y = data_binary + lr = LogisticRegression() + lr.fit(X, y) + msg = ("pos_label='invalid' is not a valid class label for " + "LogisticRegression. Expected one of {0, 1}.") + with pytest.raises(ValueError, match=msg): + plot_roc_curve(lr, X, y, pos_label="invalid") + + +def test_implicit_pos_label(data_binary): + X, y = data_binary + y = y.astype(str) + lr = LogisticRegression() + lr.fit(X, y) + msg = ("make y_true take integer value in {0, 1} or {-1, 1}" + " or pass pos_label explicitly.") + with pytest.raises(ValueError, match=msg): + plot_roc_curve(lr, X, y) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index c271781638668..a3489f7b6cd7d 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -532,7 +532,10 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): np.array_equal(classes, [0]) or np.array_equal(classes, [-1]) or np.array_equal(classes, [1]))): - raise ValueError("Data is not binary and pos_label is not specified") + raise ValueError("y_true takes value in {classes} and pos_label is not " + "specified: either make y_true take integer value in " + "{{0, 1}} or {{-1, 1}} or pass pos_label explicitly." + .format(classes=set(classes))) elif pos_label is None: pos_label = 1. From d829fa0c2973bc5c8e65a6e83eb1aaf1168a60a9 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 30 Oct 2019 15:46:49 +0100 Subject: [PATCH 2/8] Fix elementwise comparison warning --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 06173da4a90c1..20201c9f42338 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -187,7 +187,7 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, if (pos_label is not None and hasattr(estimator, "classes_") - and pos_label not in estimator.classes_): + and pos_label not in set(estimator.classes_)): estimator_name = estimator.__class__.__name__ expected_classes = set(estimator.classes_) raise ValueError("pos_label={} is not a valid class label for {}. " From ccb42bfe51b62bac43687ade9d9e552409b122b3 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 30 Oct 2019 15:48:27 +0100 Subject: [PATCH 3/8] pep8 --- sklearn/metrics/_ranking.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index a3489f7b6cd7d..c93c46c0201dc 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -532,10 +532,10 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): np.array_equal(classes, [0]) or np.array_equal(classes, [-1]) or np.array_equal(classes, [1]))): - raise ValueError("y_true takes value in {classes} and pos_label is not " - "specified: either make y_true take integer value in " - "{{0, 1}} or {{-1, 1}} or pass pos_label explicitly." - .format(classes=set(classes))) + raise ValueError("y_true takes value in {classes} and pos_label is " + "not specified: either make y_true take integer " + "value in {{0, 1}} or {{-1, 1}} or pass pos_label " + "explicitly.".format(classes=set(classes))) elif pos_label is None: pos_label = 1. From 93d497c97a2320d7720adec923d7eef6059f9091 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 30 Oct 2019 16:23:13 +0100 Subject: [PATCH 4/8] Missing pyplot fixture in test --- sklearn/metrics/_plot/tests/test_plot_roc_curve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py index 481d5d07ebac4..c12f6d328a82a 100644 --- a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py +++ b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py @@ -95,7 +95,7 @@ def test_plot_roc_curve(pyplot, response_method, data_binary, assert viz.ax_.get_xlabel() == "False Positive Rate" -def test_invalid_pos_label(data_binary): +def test_invalid_pos_label(pyplot, data_binary): X, y = data_binary lr = LogisticRegression() lr.fit(X, y) @@ -105,7 +105,7 @@ def test_invalid_pos_label(data_binary): plot_roc_curve(lr, X, y, pos_label="invalid") -def test_implicit_pos_label(data_binary): +def test_implicit_pos_label(pyplot, data_binary): X, y = data_binary y = y.astype(str) lr = LogisticRegression() From 0a74233b2cb3faef39e2fbc969e892f2eed83014 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 6 Nov 2019 11:43:30 +0100 Subject: [PATCH 5/8] Fix numpy FutureWarning assert_equal on str elements with numerical elements. --- sklearn/metrics/_ranking.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index c93c46c0201dc..ee14f6454143a 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -31,6 +31,7 @@ from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum from ..utils.sparsefuncs import count_nonzero +from ..utils import _determine_key_type from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize from ..preprocessing._label import _encode @@ -526,12 +527,13 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): # ensure binary classification if pos_label is not specified classes = np.unique(y_true) - if (pos_label is None and - not (np.array_equal(classes, [0, 1]) or - np.array_equal(classes, [-1, 1]) or - np.array_equal(classes, [0]) or - np.array_equal(classes, [-1]) or - np.array_equal(classes, [1]))): + if (pos_label is None and ( + _determine_key_type(classes) == 'str' or + not (np.array_equal(classes, [0, 1]) or + np.array_equal(classes, [-1, 1]) or + np.array_equal(classes, [0]) or + np.array_equal(classes, [-1]) or + np.array_equal(classes, [1])))): raise ValueError("y_true takes value in {classes} and pos_label is " "not specified: either make y_true take integer " "value in {{0, 1}} or {{-1, 1}} or pass pos_label " From abf3cd70ae85df49c3baef32ee4e7ed5ca42a7a5 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 6 Nov 2019 11:49:56 +0100 Subject: [PATCH 6/8] Update sklearn/metrics/_plot/roc_curve.py Co-Authored-By: Thomas J Fan --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 20201c9f42338..06173da4a90c1 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -187,7 +187,7 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, if (pos_label is not None and hasattr(estimator, "classes_") - and pos_label not in set(estimator.classes_)): + and pos_label not in estimator.classes_): estimator_name = estimator.__class__.__name__ expected_classes = set(estimator.classes_) raise ValueError("pos_label={} is not a valid class label for {}. " From 0a0b9f50298214c95839cd1e94c778b4da565260 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 6 Nov 2019 12:29:10 +0100 Subject: [PATCH 7/8] test implicit pos_label message for precision_recall_curve instead of plot_roc_curve --- sklearn/metrics/_plot/roc_curve.py | 4 +++- sklearn/metrics/_plot/tests/test_plot_roc_curve.py | 11 ----------- sklearn/metrics/tests/test_ranking.py | 11 ++++++++++- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 06173da4a90c1..07b7d0f4aef1f 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -187,7 +187,9 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, if (pos_label is not None and hasattr(estimator, "classes_") - and pos_label not in estimator.classes_): + and pos_label not in set(estimator.classes_)): + # Note: set(estimator.classes_) is required to compare str and + # int/float values without raising a numpy FutureWarning. estimator_name = estimator.__class__.__name__ expected_classes = set(estimator.classes_) raise ValueError("pos_label={} is not a valid class label for {}. " diff --git a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py index c12f6d328a82a..ff90bacf41221 100644 --- a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py +++ b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py @@ -103,14 +103,3 @@ def test_invalid_pos_label(pyplot, data_binary): "LogisticRegression. Expected one of {0, 1}.") with pytest.raises(ValueError, match=msg): plot_roc_curve(lr, X, y, pos_label="invalid") - - -def test_implicit_pos_label(pyplot, data_binary): - X, y = data_binary - y = y.astype(str) - lr = LogisticRegression() - lr.fit(X, y) - msg = ("make y_true take integer value in {0, 1} or {-1, 1}" - " or pass pos_label explicitly.") - with pytest.raises(ValueError, match=msg): - plot_roc_curve(lr, X, y) diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 60ab3784aceb8..fc71b6881e891 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -662,7 +662,7 @@ def test_auc_score_non_binary_class(): roc_auc_score(y_true, y_pred) -def test_binary_clf_curve(): +def test_binary_clf_curve_multiclass_error(): rng = check_random_state(404) y_true = rng.randint(0, 3, size=10) y_pred = rng.rand(10) @@ -671,6 +671,15 @@ def test_binary_clf_curve(): precision_recall_curve(y_true, y_pred) +def test_binary_clf_curve_implicit_pos_label(): + y_true = ["a", "b"] + y_pred = [0., 1.] + msg = ("make y_true take integer value in {0, 1} or {-1, 1}" + " or pass pos_label explicitly.") + with pytest.raises(ValueError, match=msg): + precision_recall_curve(y_true, y_pred) + + def test_precision_recall_curve(): y_true, _, probas_pred = make_prediction(binary=True) _test_precision_recall_curve(y_true, probas_pred) From 4bbceb61bed4943390aa269b21ba7eeca06cc1f9 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 6 Nov 2019 16:58:38 +0100 Subject: [PATCH 8/8] Simplify check for invalid pos_label --- sklearn/metrics/_plot/roc_curve.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 07b7d0f4aef1f..b6182b1a513a9 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -185,13 +185,11 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, if prediction_method is None: raise ValueError('response methods not defined') - if (pos_label is not None - and hasattr(estimator, "classes_") - and pos_label not in set(estimator.classes_)): - # Note: set(estimator.classes_) is required to compare str and - # int/float values without raising a numpy FutureWarning. + # Note: set(estimator.classes_) is required to compare str and + # int/float values without raising a numpy FutureWarning. + expected_classes = set(estimator.classes_) + if pos_label is not None and pos_label not in expected_classes: estimator_name = estimator.__class__.__name__ - expected_classes = set(estimator.classes_) raise ValueError("pos_label={} is not a valid class label for {}. " "Expected one of {}." .format(repr(pos_label),