-
-
Notifications
You must be signed in to change notification settings - Fork 26k
Improve error message with invalid pos_label in plot_roc_curve and implicit pos_label in precision_recall_curve #15405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2636423
d829fa0
ccb42bf
93d497c
0a74233
abf3cd7
0a0b9f5
bbbb2aa
4bbceb6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
from ..utils.multiclass import type_of_target | ||
from ..utils.extmath import stable_cumsum | ||
from ..utils.sparsefuncs import count_nonzero | ||
from ..utils import _determine_key_type | ||
from ..exceptions import UndefinedMetricWarning | ||
from ..preprocessing import label_binarize | ||
from ..preprocessing._label import _encode | ||
|
@@ -526,13 +527,17 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
|
||
# ensure binary classification if pos_label is not specified | ||
classes = np.unique(y_true) | ||
if (pos_label is None and | ||
not (np.array_equal(classes, [0, 1]) or | ||
np.array_equal(classes, [-1, 1]) or | ||
np.array_equal(classes, [0]) or | ||
np.array_equal(classes, [-1]) or | ||
np.array_equal(classes, [1]))): | ||
raise ValueError("Data is not binary and pos_label is not specified") | ||
if (pos_label is None and ( | ||
_determine_key_type(classes) == 'str' or | ||
not (np.array_equal(classes, [0, 1]) or | ||
np.array_equal(classes, [-1, 1]) or | ||
np.array_equal(classes, [0]) or | ||
np.array_equal(classes, [-1]) or | ||
np.array_equal(classes, [1])))): | ||
raise ValueError("y_true takes value in {classes} and pos_label is " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that the error message here need to be updated. |
||
"not specified: either make y_true take integer " | ||
"value in {{0, 1}} or {{-1, 1}} or pass pos_label " | ||
"explicitly.".format(classes=set(classes))) | ||
elif pos_label is None: | ||
pos_label = 1. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -662,7 +662,7 @@ def test_auc_score_non_binary_class(): | |
roc_auc_score(y_true, y_pred) | ||
|
||
|
||
def test_binary_clf_curve(): | ||
def test_binary_clf_curve_multiclass_error(): | ||
rng = check_random_state(404) | ||
y_true = rng.randint(0, 3, size=10) | ||
y_pred = rng.rand(10) | ||
|
@@ -671,6 +671,15 @@ def test_binary_clf_curve(): | |
precision_recall_curve(y_true, y_pred) | ||
|
||
|
||
def test_binary_clf_curve_implicit_pos_label(): | ||
y_true = ["a", "b"] | ||
y_pred = [0., 1.] | ||
msg = ("make y_true take integer value in {0, 1} or {-1, 1}" | ||
" or pass pos_label explicitly.") | ||
with pytest.raises(ValueError, match=msg): | ||
precision_recall_curve(y_true, y_pred) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @qinhanmin2014 I have changed the test in this PR to trigger the exception in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR is still useful to:
|
||
|
||
|
||
def test_precision_recall_curve(): | ||
y_true, _, probas_pred = make_prediction(binary=True) | ||
_test_precision_recall_curve(y_true, probas_pred) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need something like the following?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks indeed, I updated my PR based on your suggestion and I could get rid of the remaining numpy
FutureWarning
inarray_equal
.