FIX improve error message with string-encoded target in metrics #18192

Merged
merged 10 commits on Sep 2, 2020
25 changes: 21 additions & 4 deletions sklearn/metrics/_classification.py
@@ -101,7 +101,20 @@ def _check_targets(y_true, y_pred):
y_true = column_or_1d(y_true)
y_pred = column_or_1d(y_pred)
if y_type == "binary":
unique_values = np.union1d(y_true, y_pred)
try:
unique_values = np.union1d(y_true, y_pred)
except TypeError as e:
# We expect y_true and y_pred to be of the same data type.
# If `y_true` was provided to the classifier as strings,
# `y_pred` given by the classifier will also be encoded with
# strings. So we raise a meaningful error
raise TypeError(
f"Labels in y_true and y_pred should be of the same type. "
f"Got y_true={np.unique(y_true)} and "
f"y_pred={np.unique(y_pred)}. Make sure that the "
f"predictions provided by the classifier coincide with "
f"the true labels."
) from e
if len(unique_values) > 2:
y_type = "multiclass"
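
For illustration (not part of the diff): a minimal sketch of how the new TypeError surfaces once this patch is applied; the label sets quoted in the message depend on the inputs.

import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array(["spam", "eggs", "spam"], dtype=object)  # string labels
y_pred = np.array([0, 1, 0])                               # integer predictions

try:
    accuracy_score(y_true, y_pred)
except TypeError as exc:
    # Re-raised by _check_targets as:
    # "Labels in y_true and y_pred should be of the same type. ..."
    print(exc)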

@@ -1252,13 +1265,17 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
str(average_options))

y_type, y_true, y_pred = _check_targets(y_true, y_pred)
present_labels = unique_labels(y_true, y_pred)
# Convert to Python primitive type to avoid NumPy type / Python str
# comparison. See https://github.com/numpy/numpy/issues/6784
present_labels = unique_labels(y_true, y_pred).tolist()
if average == 'binary':
if y_type == 'binary':
if pos_label not in present_labels:
if len(present_labels) >= 2:
raise ValueError("pos_label=%r is not a valid label: "
"%r" % (pos_label, present_labels))
raise ValueError(
f"pos_label={pos_label} is not a valid label. It "
f"should be one of {present_labels}"
)
labels = [pos_label]
else:
average_options = list(average_options)
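For illustration (not part of the diff): a sketch of the reworded pos_label error, assuming this patch; the .tolist() conversion keeps the membership check and the message working with Python strings.

from sklearn.metrics import f1_score

y_true = ["spam", "eggs", "spam", "eggs"]
y_pred = ["spam", "spam", "spam", "eggs"]

# pos_label defaults to 1, which is not among the string labels, so this raises:
# ValueError: pos_label=1 is not a valid label. It should be one of ['eggs', 'spam']
f1_score(y_true, y_pred, average="binary")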
10 changes: 7 additions & 3 deletions sklearn/metrics/_ranking.py
@@ -208,10 +208,14 @@ def _binary_uninterpolated_average_precision(
"multilabel-indicator y_true. Do not set "
"pos_label or set pos_label to 1.")
elif y_type == "binary":
present_labels = np.unique(y_true)
# Convert to Python primitive type to avoid NumPy type / Python str
# comparison. See https://github.com/numpy/numpy/issues/6784
present_labels = np.unique(y_true).tolist()
if len(present_labels) == 2 and pos_label not in present_labels:
raise ValueError("pos_label=%r is invalid. Set it to a label in "
"y_true." % pos_label)
raise ValueError(
f"pos_label={pos_label} is not a valid label. It should be "
f"one of {present_labels}"
)
average_precision = partial(_binary_uninterpolated_average_precision,
pos_label=pos_label)
return _average_binary_score(average_precision, y_true, y_score,
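For illustration (not part of the diff): the corresponding behaviour for average_precision_score with this patch applied.

import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([0, 1, 0, 1])
y_score = np.array([0.1, 0.8, 0.3, 0.9])

# pos_label=2 does not appear in y_true, so this raises:
# ValueError: pos_label=2 is not a valid label. It should be one of [0, 1]
average_precision_score(y_true, y_score, pos_label=2)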
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_classification.py
@@ -1247,7 +1247,7 @@ def test_multilabel_hamming_loss():
def test_jaccard_score_validation():
y_true = np.array([0, 1, 0, 1, 1])
y_pred = np.array([0, 1, 0, 1, 1])
err_msg = r"pos_label=2 is not a valid label: array\(\[0, 1\]\)"
err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
with pytest.raises(ValueError, match=err_msg):
jaccard_score(y_true, y_pred, average='binary', pos_label=2)

65 changes: 65 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -1,5 +1,6 @@

from functools import partial
from inspect import signature
from itertools import product
from itertools import chain
from itertools import permutations
@@ -1412,3 +1413,67 @@ def test_thresholded_metric_permutation_invariance(name):

current_score = metric(y_true_perm, y_score_perm)
assert_almost_equal(score, current_score)


@pytest.mark.parametrize("metric_name", CLASSIFICATION_METRICS)
def test_metrics_consistent_type_error(metric_name):
# check that an understandable message is raised when the types of y_true
# and y_pred mismatch
rng = np.random.RandomState(42)
y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
y2 = rng.randint(0, 2, size=y1.size)

err_msg = "Labels in y_true and y_pred should be of the same type."
with pytest.raises(TypeError, match=err_msg):
CLASSIFICATION_METRICS[metric_name](y1, y2)


@pytest.mark.parametrize(
"metric, y_pred_threshold",
[
(average_precision_score, True),
# FIXME: `brier_score_loss` does not follow this convention.
# See discussion in:
# https://github.com/scikit-learn/scikit-learn/issues/18307
pytest.param(
brier_score_loss, True, marks=pytest.mark.xfail(reason="#18307")
),
(f1_score, False),
(partial(fbeta_score, beta=1), False),
(jaccard_score, False),
(precision_recall_curve, True),
(precision_score, False),
(recall_score, False),
(roc_curve, True),
],
)
@pytest.mark.parametrize("dtype_y_str", [str, object])
def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
# check the error message when `pos_label` is not specified and the
# targets are made of strings.
rng = np.random.RandomState(42)
y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
y2 = rng.randint(0, 2, size=y1.size)

if not y_pred_threshold:
y2 = np.array(["spam", "eggs"], dtype=dtype_y_str)[y2]

err_msg_pos_label_None = (
"y_true takes value in {'eggs', 'spam'} and pos_label is not "
"specified: either make y_true take value in {0, 1} or {-1, 1} or "
"pass pos_label explicit"
)
err_msg_pos_label_1 = (
r"pos_label=1 is not a valid label. It should be one of "
r"\['eggs', 'spam'\]"
)

pos_label_default = signature(metric).parameters["pos_label"].default

err_msg = (
err_msg_pos_label_1
if pos_label_default == 1
else err_msg_pos_label_None
)
with pytest.raises(ValueError, match=err_msg):
metric(y1, y2)
13 changes: 7 additions & 6 deletions sklearn/metrics/tests/test_ranking.py
@@ -888,17 +888,18 @@ def test_average_precision_score_pos_label_errors():
# Raise an error when pos_label is not in binary y_true
y_true = np.array([0, 1])
y_pred = np.array([0, 1])
error_message = ("pos_label=2 is invalid. Set it to a label in y_true.")
with pytest.raises(ValueError, match=error_message):
err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
with pytest.raises(ValueError, match=err_msg):
average_precision_score(y_true, y_pred, pos_label=2)
# Raise an error for multilabel-indicator y_true with
# pos_label other than 1
y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
error_message = ("Parameter pos_label is fixed to 1 for multilabel"
"-indicator y_true. Do not set pos_label or set "
"pos_label to 1.")
with pytest.raises(ValueError, match=error_message):
err_msg = (
"Parameter pos_label is fixed to 1 for multilabel-indicator y_true. "
"Do not set pos_label or set pos_label to 1."
)
with pytest.raises(ValueError, match=err_msg):
average_precision_score(y_true, y_pred, pos_label=0)

