Improve error message with invalid pos_label in plot_roc_curve and implicit pos_label in precision_recall_curve #15405

Closed
11 changes: 11 additions & 0 deletions sklearn/metrics/_plot/roc_curve.py
@@ -185,6 +185,17 @@ def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
     if prediction_method is None:
         raise ValueError('response methods not defined')
 
+    # Note: set(estimator.classes_) is required to compare str and
+    # int/float values without raising a numpy FutureWarning.
+    expected_classes = set(estimator.classes_)
+    if pos_label is not None and pos_label not in expected_classes:
+        estimator_name = estimator.__class__.__name__
+        raise ValueError("pos_label={} is not a valid class label for {}. "
+                         "Expected one of {}."
+                         .format(repr(pos_label),
+                                 estimator_name,
+                                 expected_classes))
+
     y_pred = prediction_method(X)
 
     if y_pred.ndim != 1:
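As a side note on the comment above about set(estimator.classes_): a minimal standalone sketch (not part of the PR) of the NumPy behaviour it works around. On the NumPy releases current at the time, comparing a string against a numeric array could emit a FutureWarning, while a membership test against a plain Python set involves no NumPy comparison at all; the exact warning depends on the NumPy version.

import warnings
import numpy as np

classes = np.array([0, 1])      # stands in for estimator.classes_ on a binary problem

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = classes == "invalid"    # mixed str/int comparison on the ndarray
print([w.category.__name__ for w in caught])   # may contain 'FutureWarning' on older NumPy

# A plain Python set compares ordinary Python objects instead,
# so the membership test below raises no NumPy warning.
print("invalid" in set(classes))                # False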
10 changes: 10 additions & 0 deletions sklearn/metrics/_plot/tests/test_plot_roc_curve.py
@@ -93,3 +93,13 @@ def test_plot_roc_curve(pyplot, response_method, data_binary,
     assert viz.line_.get_label() == expected_label
     assert viz.ax_.get_ylabel() == "True Positive Rate"
     assert viz.ax_.get_xlabel() == "False Positive Rate"
+
+
+def test_invalid_pos_label(pyplot, data_binary):
+    X, y = data_binary
+    lr = LogisticRegression()
+    lr.fit(X, y)
+    msg = ("pos_label='invalid' is not a valid class label for "
+           "LogisticRegression. Expected one of {0, 1}.")
+    with pytest.raises(ValueError, match=msg):
+        plot_roc_curve(lr, X, y, pos_label="invalid")
19 changes: 12 additions & 7 deletions sklearn/metrics/_ranking.py
@@ -31,6 +31,7 @@
 from ..utils.multiclass import type_of_target
 from ..utils.extmath import stable_cumsum
 from ..utils.sparsefuncs import count_nonzero
+from ..utils import _determine_key_type
 from ..exceptions import UndefinedMetricWarning
 from ..preprocessing import label_binarize
 from ..preprocessing._label import _encode
@@ -526,13 +527,17 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
 
     # ensure binary classification if pos_label is not specified
     classes = np.unique(y_true)
-    if (pos_label is None and
-            not (np.array_equal(classes, [0, 1]) or
-                 np.array_equal(classes, [-1, 1]) or
-                 np.array_equal(classes, [0]) or
-                 np.array_equal(classes, [-1]) or
-                 np.array_equal(classes, [1]))):
-        raise ValueError("Data is not binary and pos_label is not specified")
+    if (pos_label is None and (
+            _determine_key_type(classes) == 'str' or
+            not (np.array_equal(classes, [0, 1]) or
+                 np.array_equal(classes, [-1, 1]) or
+                 np.array_equal(classes, [0]) or
+                 np.array_equal(classes, [-1]) or
+                 np.array_equal(classes, [1])))):
+        raise ValueError("y_true takes value in {classes} and pos_label is "
+                         "not specified: either make y_true take integer "
+                         "value in {{0, 1}} or {{-1, 1}} or pass pos_label "
+                         "explicitly.".format(classes=set(classes)))
     elif pos_label is None:
         pos_label = 1.
 

Member:

Do we need something like the following?

if (pos_label is None and _determine_key_type(classes) == 'str'
    ...

Member Author:

Thanks indeed, I updated my PR based on your suggestion and I could get rid of the remaining numpy FutureWarning in array_equal.

Member:

I agree that the error message here needs to be updated.
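For context, a standalone sketch (not part of the PR) of the FutureWarning the suggestion above removes: with string labels, every np.array_equal call against the integer templates is a mixed-type comparison, so checking for string labels first short-circuits before any of those calls run. The exact warning behaviour depends on the NumPy release.

import warnings
import numpy as np

def warnings_from_array_equal(classes):
    # Record any warnings emitted while comparing the labels to [0, 1].
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        np.array_equal(classes, [0, 1])
    return [w.category.__name__ for w in caught]

print(warnings_from_array_equal(np.array([0, 1])))      # [] -- same dtype, nothing to warn about
print(warnings_from_array_equal(np.array(["a", "b"])))  # may report FutureWarning on older NumPy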
11 changes: 10 additions & 1 deletion sklearn/metrics/tests/test_ranking.py
@@ -662,7 +662,7 @@ def test_auc_score_non_binary_class():
         roc_auc_score(y_true, y_pred)
 
 
-def test_binary_clf_curve():
+def test_binary_clf_curve_multiclass_error():
     rng = check_random_state(404)
     y_true = rng.randint(0, 3, size=10)
     y_pred = rng.rand(10)
@@ -671,6 +671,15 @@ def test_binary_clf_curve():
         precision_recall_curve(y_true, y_pred)
 
 
+def test_binary_clf_curve_implicit_pos_label():
+    y_true = ["a", "b"]
+    y_pred = [0., 1.]
+    msg = ("make y_true take integer value in {0, 1} or {-1, 1}"
+           " or pass pos_label explicitly.")
+    with pytest.raises(ValueError, match=msg):
+        precision_recall_curve(y_true, y_pred)
+
+
 def test_precision_recall_curve():
     y_true, _, probas_pred = make_prediction(binary=True)
     _test_precision_recall_curve(y_true, probas_pred)

Member Author (@ogrisel, Nov 6, 2019):

@qinhanmin2014 I have changed the test in this PR to trigger the exception in _binary_clf_curve from precision_recall_curve rather than plot_roc_curve, so that we can merge both #15316 and this PR while also improving the error message.

Member Author (@ogrisel, Nov 6, 2019):

This PR is still useful to:

  • raise a meaningful error message when the user passes an invalid pos_label value;
  • improve the error message when passing string labels without an explicit pos_label to functions that do not automatically label-encode string labels, such as precision_recall_curve (see the usage sketch below);
  • have more tests.
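To illustrate the second point with the public API, a usage sketch assuming the behaviour introduced by this diff (the labels and scores are made up):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array(["spam", "ham", "spam", "ham"])
y_score = np.array([0.9, 0.2, 0.8, 0.4])

# Without pos_label, string labels now hit the explicit error message added
# in this diff instead of the old "Data is not binary" one.
try:
    precision_recall_curve(y_true, y_score)
except ValueError as exc:
    print(exc)   # ... either make y_true take integer value in {0, 1} or {-1, 1}
                 # or pass pos_label explicitly.

# Passing pos_label explicitly selects which class counts as positive.
precision, recall, thresholds = precision_recall_curve(
    y_true, y_score, pos_label="spam")
print(recall)    # decreasing recall values, ending at 0.0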