diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 35e6e9512f105..b6182b1a513a9 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -185,6 +185,17 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None, if prediction_method is None: raise ValueError('response methods not defined') + # Note: set(estimator.classes_) is required to compare str and + # int/float values without raising a numpy FutureWarning. + expected_classes = set(estimator.classes_) + if pos_label is not None and pos_label not in expected_classes: + estimator_name = estimator.__class__.__name__ + raise ValueError("pos_label={} is not a valid class label for {}. " + "Expected one of {}." + .format(repr(pos_label), + estimator_name, + expected_classes)) + y_pred = prediction_method(X) if y_pred.ndim != 1: diff --git a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py index 1e928f55d8e73..ff90bacf41221 100644 --- a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py +++ b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py @@ -93,3 +93,13 @@ def test_plot_roc_curve(pyplot, response_method, data_binary, assert viz.line_.get_label() == expected_label assert viz.ax_.get_ylabel() == "True Positive Rate" assert viz.ax_.get_xlabel() == "False Positive Rate" + + +def test_invalid_pos_label(pyplot, data_binary): + X, y = data_binary + lr = LogisticRegression() + lr.fit(X, y) + msg = ("pos_label='invalid' is not a valid class label for " + "LogisticRegression. Expected one of {0, 1}.") + with pytest.raises(ValueError, match=msg): + plot_roc_curve(lr, X, y, pos_label="invalid") diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index c271781638668..ee14f6454143a 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -31,6 +31,7 @@ from ..utils.multiclass import type_of_target from ..utils.extmath import stable_cumsum from ..utils.sparsefuncs import count_nonzero +from ..utils import _determine_key_type from ..exceptions import UndefinedMetricWarning from ..preprocessing import label_binarize from ..preprocessing._label import _encode @@ -526,13 +527,17 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): # ensure binary classification if pos_label is not specified classes = np.unique(y_true) - if (pos_label is None and - not (np.array_equal(classes, [0, 1]) or - np.array_equal(classes, [-1, 1]) or - np.array_equal(classes, [0]) or - np.array_equal(classes, [-1]) or - np.array_equal(classes, [1]))): - raise ValueError("Data is not binary and pos_label is not specified") + if (pos_label is None and ( + _determine_key_type(classes) == 'str' or + not (np.array_equal(classes, [0, 1]) or + np.array_equal(classes, [-1, 1]) or + np.array_equal(classes, [0]) or + np.array_equal(classes, [-1]) or + np.array_equal(classes, [1])))): + raise ValueError("y_true takes value in {classes} and pos_label is " + "not specified: either make y_true take integer " + "value in {{0, 1}} or {{-1, 1}} or pass pos_label " + "explicitly.".format(classes=set(classes))) elif pos_label is None: pos_label = 1.
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 0275a26055915..18c9444184231 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -662,7 +662,7 @@ def test_auc_score_non_binary_class(): roc_auc_score(y_true, y_pred) -def test_binary_clf_curve(): +def test_binary_clf_curve_multiclass_error(): rng = check_random_state(404) y_true = rng.randint(0, 3, size=10) y_pred = rng.rand(10) @@ -671,6 +671,15 @@ def test_binary_clf_curve(): precision_recall_curve(y_true, y_pred) +def test_binary_clf_curve_implicit_pos_label(): + y_true = ["a", "b"] + y_pred = [0., 1.] + msg = ("make y_true take integer value in {0, 1} or {-1, 1}" + " or pass pos_label explicitly.") + with pytest.raises(ValueError, match=msg): + precision_recall_curve(y_true, y_pred) + + def test_precision_recall_curve(): y_true, _, probas_pred = make_prediction(binary=True) _test_precision_recall_curve(y_true, probas_pred)