Improve error message with invalid pos_label in plot_roc_curve and implicit pos_label in precision_recall_curve #15405
Conversation
Note that … only raises a … Ideally I would also like to have:

```diff
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index 60ab3784a..f8c0ead08 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -662,6 +662,14 @@ def test_auc_score_non_binary_class():
         roc_auc_score(y_true, y_pred)


+def test_roc_auc_score_implicit_pos_label():
+    y_true = ["a", "b", "a"]
+    y_pred = [0.2, 0.3, 0.1]
+    err_msg = "42"
+    with pytest.raises(ValueError, match=err_msg):
+        roc_auc_score(y_true, y_pred)
+
+
 def test_binary_clf_curve():
     rng = check_random_state(404)
     y_true = rng.randint(0, 3, size=10)
```

However at the moment … If we decide that the fix for …
```diff
@@ -532,7 +532,10 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
             np.array_equal(classes, [0]) or
             np.array_equal(classes, [-1]) or
             np.array_equal(classes, [1]))):
-        raise ValueError("Data is not binary and pos_label is not specified")
+        raise ValueError("y_true takes value in {classes} and pos_label is "
```
Do we need something like the following?

```python
if (pos_label is None and _determine_key_type(classes) == 'str'
    ...
```
Thanks indeed, I updated my PR based on your suggestion and I could get rid of the remaining numpy `FutureWarning` in `array_equal`.
I think Andy wants to infer `pos_label` automatically? I also prefer to infer `pos_label`, since we accept a trained estimator here (it seems that Thomas also agrees with Andy?).
There's a PR but we finally decided to infer `pos_label` automatically. I was involved in that decision but now I regret it. I think it's better to introduce a `pos_label` parameter (which is consistent with `roc_auc_score`).
I can go either way as long as we are consistent. I agree with @ogrisel that with string targets `pos_label` should be set and not automatically inferred.
```diff
@@ -532,7 +532,10 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
             np.array_equal(classes, [0]) or
             np.array_equal(classes, [-1]) or
             np.array_equal(classes, [1]))):
-        raise ValueError("Data is not binary and pos_label is not specified")
+        raise ValueError("y_true takes value in {classes} and pos_label is "
```
I agree that the error message here needs to be updated.
Will this actually close Andy's issue?
I think the best solution is to deprecate `pos_label=None` and use `pos_label=1` as the default (this is consistent with P/R/F, jaccard and `average_precision_score`; we need to do the deprecation in `brier_score_loss`, `roc_curve` and `precision_recall_curve`, and introduce `pos_label` in `roc_auc_score`).
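As a hedged illustration of the convention discussed above (the data below are arbitrary examples, not the PR's code): `precision_score` and the other P/R/F metrics already default to `pos_label=1`, which is the convention this comment proposes to generalize.

```python
# Illustration only: P/R/F metrics already use pos_label=1 as the default.
from sklearn.metrics import precision_score

y_true = [1, 0, 1, 1]
y_pred = [1, 0, 0, 1]

p_default = precision_score(y_true, y_pred)            # implicit pos_label=1
p_class0 = precision_score(y_true, y_pred, pos_label=0)

print(p_default)  # 1.0: both predictions of class 1 are correct
print(p_class0)   # 0.5: one of the two predictions of class 0 is correct
```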
Just a reminder that if this changes, the 'roc_auc' scorer still needs to work in binary cases with string labels. The choice of positive class will not affect the result.
I am not sure I understand. If you have …
The fact that you get different scores with different string label encodings of fundamentally the same data should be considered a bug:

```python
>>> from sklearn.metrics import roc_auc_score
>>> import numpy as np
>>> y_true = np.random.choice(["red", "blue"], size=100)
... y_true2 = np.where(y_true == "red", "0 red", "1 blue")
... y_pred = np.random.randn(100)
... print(roc_auc_score(y_true, y_pred))
... print(roc_auc_score(y_true2, y_pred))
...
0.5466893039049237
0.4533106960950764
```
I mean that you can train and test a scikit-learn estimator with whichever labelling of the binary classes, and using the roc scorer you will get the same result. This is not true of precision, recall, etc.
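A small sketch of why (values and seed are arbitrary): ROC AUC is invariant under swapping which class is called positive, provided the scores are flipped accordingly, which is exactly what happens when a classifier is trained on relabelled data; recall is not invariant.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, recall_score

rng = np.random.RandomState(0)
y_true = rng.randint(0, 2, size=100)
y_score = rng.rand(100)

# Relabelling the classes also flips the direction of the classifier's
# score for the "positive" class, and ROC AUC is invariant under that:
auc = roc_auc_score(y_true, y_score)
auc_swapped = roc_auc_score(1 - y_true, 1 - y_score)
print(np.isclose(auc, auc_swapped))  # True

# Recall, in contrast, genuinely depends on which class is positive:
y_pred = (y_score > 0.5).astype(int)
print(recall_score(y_true, y_pred), recall_score(1 - y_true, 1 - y_pred))
```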
I'm fine with this solution in `plot_*` in any case. I'll try to review later.
Tests are erroring on some platforms btw, @ogrisel.
It's true that when we use a scorer on a sklearn classifier, the implicitly inferred positive class will be consistently the same for the estimator and the scorer (the label encoder used internally in the classifier will yield consistent results with the one used internally in …).
I will have a look into it.
`assert_equal` on str elements with numerical elements.

Co-Authored-By: Thomas J Fan <thomasjpfan@gmail.com>
Actually, I think I am changing my mind: we should probably implement #15316 and just document the positive class inference in case the user passes strings.
```python
    msg = ("make y_true take integer value in {0, 1} or {-1, 1}"
           " or pass pos_label explicitly.")
    with pytest.raises(ValueError, match=msg):
        precision_recall_curve(y_true, y_pred)
```
@qinhanmin2014 I have changed the test in this PR to trigger the exception in `_binary_clf_curve` from `precision_recall_curve` rather than `plot_roc_curve`, so that we can merge both this PR and #15316 while also improving the error message.
This PR is still useful to:

- raise a meaningful error message when the user passes an invalid `pos_label` value;
- improve the error message when passing string labels without passing an explicit `pos_label` for functions that do not automatically label encode string labels (e.g. `precision_recall_curve`);
- have more tests.
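The error path described above can be exercised as follows (a sketch against current scikit-learn behaviour; the exact message wording may differ across versions):

```python
from sklearn.metrics import precision_recall_curve

y_true = ["a", "b", "a", "b"]
y_score = [0.1, 0.4, 0.35, 0.8]

# String labels without an explicit pos_label: the positive class cannot
# be guessed, so an informative ValueError is raised.
try:
    precision_recall_curve(y_true, y_score)
except ValueError as exc:
    print(type(exc).__name__)  # ValueError

# Passing pos_label explicitly works:
precision, recall, thresholds = precision_recall_curve(
    y_true, y_score, pos_label="b")
print(len(precision) == len(recall))  # True
```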
I agree. I prefer to introduce `pos_label` and set the default to 1.
But I changed my mind in the context of `plot_roc_curve`, because in this context the estimator already makes this implicit assumption internally (the last class label in alphabetical order is positive). So we can just introspect it and stay consistent with it.
Actually, in the context of the … Furthermore, setting the default of …:

```python
>>> from sklearn.linear_model import LogisticRegression
>>> clf = LogisticRegression().fit([[0], [1]], [1, 2])
>>> clf.classes_
array([1, 2])
```

For the classifier, the positive class is …
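A hypothetical sketch (not the actual scikit-learn implementation) of the introspection idea discussed here: treat `classes_[-1]` of the fitted estimator as the positive class, consistent with how `decision_function` and `predict_proba[:, 1]` are oriented.

```python
from sklearn.linear_model import LogisticRegression

X = [[0.0], [1.0], [2.0], [3.0]]
y = ["blue", "red", "blue", "red"]

clf = LogisticRegression().fit(X, y)
print(clf.classes_)  # ['blue' 'red'], sorted lexicographically

# Hypothetical helper: the positive class a plotting tool could infer.
inferred_pos_label = clf.classes_[-1]
print(inferred_pos_label)  # 'red'
```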
I think I (and possibly others) were a bit confused here. It would make sense to treat 1 as the positive class, but setting …

However, we could provide an interface that allows the user to specify which of the classes is the semantically positive class, and …

Basically: if you only apply …
So if we want a quick fix, we can remove the …
This PR still had useful improvements not directly related to the …
So numpy will raise a `FutureWarning` for things like …
Indeed:

```python
>>> import numpy as np
>>> np.array_equal(["a", "b"], [1, 2])
/home/ogrisel/miniconda3/envs/pip/lib/python3.7/site-packages/numpy/core/numeric.py:2339: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison
  return bool(asarray(a1 == a2).all())
```
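One way to sidestep the warning (a sketch, not scikit-learn's actual fix, which relies on `_determine_key_type`): check the dtype kinds before falling back to `np.array_equal`, since a string/numeric pair can never be elementwise equal anyway. The helper name below is invented for illustration.

```python
import numpy as np

def safe_array_equal(a, b):
    """Compare two arrays without triggering the elementwise-comparison
    FutureWarning for mixed string/numeric inputs."""
    a, b = np.asarray(a), np.asarray(b)
    a_is_str = a.dtype.kind in ("U", "S", "O")
    b_is_str = b.dtype.kind in ("U", "S", "O")
    if a_is_str != b_is_str:
        return False  # mixed kinds can never be elementwise equal
    return np.array_equal(a, b)

print(safe_array_equal(["a", "b"], [1, 2]))  # False, no warning
print(safe_array_equal([0, 1], [0, 1]))      # True
```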
I extracted the good part and rebased on top of current master in #15562. |
This is a potential fix for the issue raised by @amueller in #15303.

Edit: this PR now aims to improve the error messages in cases of invalid inputs. The actual fix for #15303 for `plot_roc_curve` is to be implemented in #15316.