diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index b67f5bd972c1d..8a4a148272257 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -101,7 +101,20 @@ def _check_targets(y_true, y_pred):
         y_true = column_or_1d(y_true)
         y_pred = column_or_1d(y_pred)
         if y_type == "binary":
-            unique_values = np.union1d(y_true, y_pred)
+            try:
+                unique_values = np.union1d(y_true, y_pred)
+            except TypeError as e:
+                # We expect y_true and y_pred to be of the same data type.
+                # If `y_true` was provided to the classifier as strings,
+                # `y_pred` given by the classifier will also be encoded with
+                # strings. So we raise a meaningful error.
+                raise TypeError(
+                    f"Labels in y_true and y_pred should be of the same type. "
+                    f"Got y_true={np.unique(y_true)} and "
+                    f"y_pred={np.unique(y_pred)}. Make sure that the "
+                    f"predictions provided by the classifier coincide with "
+                    f"the true labels."
+                ) from e
             if len(unique_values) > 2:
                 y_type = "multiclass"

@@ -1252,13 +1265,17 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
                          str(average_options))

     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
-    present_labels = unique_labels(y_true, y_pred)
+    # Convert to Python primitive type to avoid NumPy type / Python str
+    # comparison. See https://github.com/numpy/numpy/issues/6784
+    present_labels = unique_labels(y_true, y_pred).tolist()
     if average == 'binary':
         if y_type == 'binary':
             if pos_label not in present_labels:
                 if len(present_labels) >= 2:
-                    raise ValueError("pos_label=%r is not a valid label: "
-                                     "%r" % (pos_label, present_labels))
+                    raise ValueError(
+                        f"pos_label={pos_label} is not a valid label. It "
+                        f"should be one of {present_labels}"
+                    )
             labels = [pos_label]
         else:
             average_options = list(average_options)
diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 6727de0c05c65..813c9892624ed 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -208,10 +208,14 @@ def _binary_uninterpolated_average_precision(
                          "multilabel-indicator y_true. Do not set "
                          "pos_label or set pos_label to 1.")
     elif y_type == "binary":
-        present_labels = np.unique(y_true)
+        # Convert to Python primitive type to avoid NumPy type / Python str
+        # comparison. See https://github.com/numpy/numpy/issues/6784
+        present_labels = np.unique(y_true).tolist()
         if len(present_labels) == 2 and pos_label not in present_labels:
-            raise ValueError("pos_label=%r is invalid. Set it to a label in "
-                             "y_true." % pos_label)
+            raise ValueError(
+                f"pos_label={pos_label} is not a valid label. It should be "
+                f"one of {present_labels}"
+            )
     average_precision = partial(_binary_uninterpolated_average_precision,
                                 pos_label=pos_label)
     return _average_binary_score(average_precision, y_true, y_score,
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index 6677f3119dacd..1bfe5af3a7cdf 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -1247,7 +1247,7 @@ def test_multilabel_hamming_loss():
 def test_jaccard_score_validation():
     y_true = np.array([0, 1, 0, 1, 1])
     y_pred = np.array([0, 1, 0, 1, 1])
-    err_msg = r"pos_label=2 is not a valid label: array\(\[0, 1\]\)"
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
     with pytest.raises(ValueError, match=err_msg):
         jaccard_score(y_true, y_pred, average='binary', pos_label=2)

diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py
index 24f01d46610a7..4641a7875a11d 100644
--- a/sklearn/metrics/tests/test_common.py
+++ b/sklearn/metrics/tests/test_common.py
@@ -1,5 +1,6 @@
 from functools import partial
+from inspect import signature
 from itertools import product
 from itertools import chain
 from itertools import permutations

@@ -1412,3 +1413,67 @@ def test_thresholded_metric_permutation_invariance(name):

         current_score = metric(y_true_perm, y_score_perm)
         assert_almost_equal(score, current_score)
+
+
+@pytest.mark.parametrize("metric_name", CLASSIFICATION_METRICS)
+def test_metrics_consistent_type_error(metric_name):
+    # check that an understandable message is raised when the types of y_true
+    # and y_pred do not match
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=object)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    err_msg = "Labels in y_true and y_pred should be of the same type."
+    with pytest.raises(TypeError, match=err_msg):
+        CLASSIFICATION_METRICS[metric_name](y1, y2)
+
+
+@pytest.mark.parametrize(
+    "metric, y_pred_threshold",
+    [
+        (average_precision_score, True),
+        # FIXME: `brier_score_loss` does not follow this convention.
+        # See discussion in:
+        # https://github.com/scikit-learn/scikit-learn/issues/18307
+        pytest.param(
+            brier_score_loss, True, marks=pytest.mark.xfail(reason="#18307")
+        ),
+        (f1_score, False),
+        (partial(fbeta_score, beta=1), False),
+        (jaccard_score, False),
+        (precision_recall_curve, True),
+        (precision_score, False),
+        (recall_score, False),
+        (roc_curve, True),
+    ],
+)
+@pytest.mark.parametrize("dtype_y_str", [str, object])
+def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
+    # check the error message when `pos_label` is not specified and the
+    # targets are made of strings.
+    rng = np.random.RandomState(42)
+    y1 = np.array(["spam"] * 3 + ["eggs"] * 2, dtype=dtype_y_str)
+    y2 = rng.randint(0, 2, size=y1.size)
+
+    if not y_pred_threshold:
+        y2 = np.array(["spam", "eggs"], dtype=dtype_y_str)[y2]
+
+    err_msg_pos_label_None = (
+        "y_true takes value in {'eggs', 'spam'} and pos_label is not "
+        "specified: either make y_true take value in {0, 1} or {-1, 1} or "
+        "pass pos_label explicit"
+    )
+    err_msg_pos_label_1 = (
+        r"pos_label=1 is not a valid label. It should be one of "
+        r"\['eggs', 'spam'\]"
+    )
+
+    pos_label_default = signature(metric).parameters["pos_label"].default
+
+    err_msg = (
+        err_msg_pos_label_1
+        if pos_label_default == 1
+        else err_msg_pos_label_None
+    )
+    with pytest.raises(ValueError, match=err_msg):
+        metric(y1, y2)
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index e08a8909cfe72..f49e469973f97 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -888,17 +888,18 @@ def test_average_precision_score_pos_label_errors():
     # Raise an error when pos_label is not in binary y_true
     y_true = np.array([0, 1])
     y_pred = np.array([0, 1])
-    error_message = ("pos_label=2 is invalid. Set it to a label in y_true.")
-    with pytest.raises(ValueError, match=error_message):
+    err_msg = r"pos_label=2 is not a valid label. It should be one of \[0, 1\]"
+    with pytest.raises(ValueError, match=err_msg):
         average_precision_score(y_true, y_pred, pos_label=2)

     # Raise an error for multilabel-indicator y_true with
     # pos_label other than 1
     y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])
     y_pred = np.array([[0.9, 0.1], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8]])
-    error_message = ("Parameter pos_label is fixed to 1 for multilabel"
-                     "-indicator y_true. Do not set pos_label or set "
-                     "pos_label to 1.")
-    with pytest.raises(ValueError, match=error_message):
+    err_msg = (
+        "Parameter pos_label is fixed to 1 for multilabel-indicator y_true. "
+        "Do not set pos_label or set pos_label to 1."
+    )
+    with pytest.raises(ValueError, match=err_msg):
         average_precision_score(y_true, y_pred, pos_label=0)
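
As a usage illustration (not part of the patch): a minimal sketch of how the reworded errors surface to a user, assuming a scikit-learn build with this diff applied. The label values are illustrative and mirror the ones used in the new tests; `f1_score` and `precision_score` are two of the metrics covered by the parametrized tests above.

```python
import numpy as np
from sklearn.metrics import f1_score, precision_score

y_true = np.array(["spam", "spam", "spam", "eggs", "eggs"], dtype=object)

# Mixed label types: _check_targets now wraps the low-level NumPy failure
# from np.union1d in a TypeError that names both sets of labels.
try:
    f1_score(y_true, np.array([0, 1, 0, 1, 1]))
except TypeError as exc:
    print(exc)
    # Labels in y_true and y_pred should be of the same type. Got ...

# String labels with the default pos_label=1: the ValueError now lists the
# valid labels as a plain Python list instead of a NumPy array repr.
try:
    precision_score(y_true, y_true)
except ValueError as exc:
    print(exc)
    # pos_label=1 is not a valid label. It should be one of ['eggs', 'spam']
```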