diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index da366c913f500..42a9382e2d7b7 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -240,6 +240,9 @@ Changelog ``metric='seuclidean'`` and ``X`` is not type ``np.float64``. :pr:`15730` by :user:`Forrest Koch `. +- |Enhancement| Add `pos_label` parameter to :func:`roc_auc_score`. + :pr:`17594` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.model_selection` .............................. diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index e07f61a92d478..5f738f01268d2 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -218,14 +218,16 @@ def _binary_uninterpolated_average_precision( average, sample_weight=sample_weight) -def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None): +def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None, + pos_label=None): """Binary roc auc score""" if len(np.unique(y_true)) != 2: raise ValueError("Only one class present in y_true. ROC AUC score " "is not defined in that case.") - fpr, tpr, _ = roc_curve(y_true, y_score, - sample_weight=sample_weight) + fpr, tpr, _ = roc_curve( + y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, + ) if max_fpr is None or max_fpr == 1: return auc(fpr, tpr) if max_fpr <= 0 or max_fpr > 1: @@ -248,7 +250,8 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None): @_deprecate_positional_args def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None, - max_fpr=None, multi_class="raise", labels=None): + max_fpr=None, multi_class="raise", labels=None, + pos_label=None): """Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores. @@ -327,6 +330,13 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None, If ``None``, the numerical or lexicographical order of the labels in ``y_true`` is used. + pos_label : int or str, default=None + The label of the positive class in the binary case. When + `pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, `pos_label` is + set to 1, otherwise an error will be raised. + + .. versionadded:: 0.24 + Returns ------- auc : float @@ -388,10 +398,9 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None, return _multiclass_roc_auc_score(y_true, y_score, labels, multi_class, average, sample_weight) elif y_type == "binary": - labels = np.unique(y_true) - y_true = label_binarize(y_true, classes=labels)[:, 0] return _average_binary_score(partial(_binary_roc_auc_score, - max_fpr=max_fpr), + max_fpr=max_fpr, + pos_label=pos_label), y_true, y_score, average, sample_weight=sample_weight) else: # multilabel-indicator diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index b824b9b0cbcb8..b40ad6c23db69 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -18,9 +18,10 @@ # Arnaud Joly # License: Simplified BSD +from collections import Counter from collections.abc import Iterable +from copy import deepcopy from functools import partial -from collections import Counter import numpy as np @@ -239,7 +240,13 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): y_pred = method_caller(clf, "predict_proba", X) if y_type == "binary": if y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] + if "pos_label" in self._kwargs: + col_idx = np.flatnonzero( + clf.classes_ == self._kwargs["pos_label"] + )[0] + else: + col_idx = 1 + y_pred = y_pred[:, col_idx] elif y_pred.shape[1] == 1: # not multiclass raise ValueError('got predict_proba of shape {},' ' but need classifier with two' @@ -307,7 +314,19 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): if y_type == "binary": if y_pred.shape[1] == 2: - y_pred = y_pred[:, 1] + if ( + self._score_func.__name__ == "roc_auc_score" + and "pos_label" not in self._kwargs + ): + self._kwargs["pos_label"] = clf.classes_[1] + + if "pos_label" in self._kwargs: + col_idx = np.flatnonzero( + clf.classes_ == self._kwargs["pos_label"] + )[0] + else: + col_idx = 1 + y_pred = y_pred[:, col_idx] else: raise ValueError('got predict_proba of shape {},' ' but need classifier with two' @@ -352,7 +371,7 @@ def get_scorer(scoring): 'to get valid options.' % scoring) else: scorer = scoring - return scorer + return deepcopy(scorer) def _passthrough_scorer(estimator, *args, **kwargs): diff --git a/sklearn/metrics/tests/test_common.py b/sklearn/metrics/tests/test_common.py index 3f2ba83b474c7..48015c71055a6 100644 --- a/sklearn/metrics/tests/test_common.py +++ b/sklearn/metrics/tests/test_common.py @@ -321,6 +321,17 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs): # Metrics with a "pos_label" argument METRICS_WITH_POS_LABEL = { "roc_curve", + + "roc_auc_score", + "weighted_roc_auc", + "samples_roc_auc", + "micro_roc_auc", + "ovr_roc_auc", + "weighted_ovr_roc_auc", + "ovo_roc_auc", + "weighted_ovo_roc_auc", + "partial_roc_auc", + "precision_recall_curve", "brier_score_loss", diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py index 3daafa8d196d3..dd771570e8481 100644 --- a/sklearn/metrics/tests/test_ranking.py +++ b/sklearn/metrics/tests/test_ranking.py @@ -7,9 +7,13 @@ from sklearn import datasets from sklearn import svm -from sklearn.utils.extmath import softmax from sklearn.datasets import make_multilabel_classification +from sklearn.datasets import load_breast_cancer +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split from sklearn.random_projection import _sparse_random_matrix +from sklearn.utils import shuffle +from sklearn.utils.extmath import softmax from sklearn.utils.validation import check_array, check_consistent_length from sklearn.utils.validation import check_random_state @@ -1469,3 +1473,40 @@ def test_partial_roc_auc_score(): assert_almost_equal( roc_auc_score(y_true, y_pred, max_fpr=max_fpr), _partial_roc_auc_score(y_true, y_pred, max_fpr)) + + +@pytest.mark.parametrize( + "decision_method", ["predict_proba", "decision_function"] +) +def test_roc_auc_score_pos_label(decision_method): + X, y = load_breast_cancer(return_X_y=True) + # create an highly imbalanced + idx_positive = np.flatnonzero(y == 1) + idx_negative = np.flatnonzero(y == 0) + idx_selected = np.hstack([idx_negative, idx_positive[:25]]) + X, y = X[idx_selected], y[idx_selected] + X, y = shuffle(X, y, random_state=42) + # only use 2 features to make the problem even harder + X = X[:, :2] + y = np.array( + ["cancer" if c == 1 else "not cancer" for c in y], dtype=object + ) + X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=0, + ) + + classifier = LogisticRegression() + classifier.fit(X_train, y_train) + + # sanity check to be sure the positive class is classes_[0] and that we + # are betrayed by the class imbalance + assert classifier.classes_.tolist() == ["cancer", "not cancer"] + pos_label = "cancer" + + y_pred = getattr(classifier, decision_method)(X_test) + y_pred = y_pred[:, 0] if y_pred.ndim == 2 else -y_pred + + fpr, tpr, _ = roc_curve(y_test, y_pred, pos_label=pos_label) + roc_auc = roc_auc_score(y_test, y_pred, pos_label=pos_label) + + assert roc_auc == pytest.approx(np.trapz(tpr, fpr)) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 67900b7cb77c3..48b01e638ddc0 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -3,6 +3,7 @@ import shutil import os import numbers +from copy import deepcopy from unittest.mock import Mock from functools import partial @@ -11,6 +12,7 @@ import joblib from numpy.testing import assert_allclose +from sklearn.utils import shuffle from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import ignore_warnings @@ -32,10 +34,12 @@ from sklearn.cluster import KMeans from sklearn.linear_model import Ridge, LogisticRegression, Perceptron from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor +from sklearn.datasets import load_diabetes +from sklearn.datasets import load_breast_cancer from sklearn.datasets import make_blobs -from sklearn.datasets import make_classification, make_regression +from sklearn.datasets import make_classification from sklearn.datasets import make_multilabel_classification -from sklearn.datasets import load_diabetes +from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split, cross_val_score from sklearn.model_selection import GridSearchCV from sklearn.multiclass import OneVsRestClassifier @@ -747,3 +751,95 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name): msg = "'Perceptron' object has no attribute 'predict_proba'" with pytest.raises(AttributeError, match=msg): scorer(lr, X, y) + + +@pytest.mark.parametrize( + "scoring, is_symmetric", + [ + ("roc_auc", True), + ("jaccard", False), + ("f1", False), + ("average_precision", False), + ("precision", False), + ("recall", False), + ("neg_brier_score", True), + ], +) +def test_scorer_pos_label_grid_search(scoring, is_symmetric): + # Check the behaviour for the scorer which requires a `pos_label` with + # binary target. Non-regression test for: + # https://github.com/scikit-learn/scikit-learn/pull/17572 + X, y = load_breast_cancer(return_X_y=True) + # create an highly imbalanced + idx_positive = np.flatnonzero(y == 1) + idx_negative = np.flatnonzero(y == 0) + idx_selected = np.hstack([idx_negative, idx_positive[:25]]) + X, y = X[idx_selected], y[idx_selected] + X, y = shuffle(X, y, random_state=42) + # only use 2 features to make the problem even harder + X = X[:, :2] + y = np.array( + ["cancer" if c == 1 else "not cancer" for c in y], dtype=object + ) + + param_grid = {"max_depth": [1, 3, 5]} + classifier = GridSearchCV( + DecisionTreeClassifier(random_state=0), + param_grid=param_grid, + scoring=scoring, + cv=2, + ) + + if is_symmetric: + # we will expand to compute for several scorer with different pos_label + # which should all give the same results + scorer = get_scorer(scoring) + scorer_pos_label, scorer_neg_label = deepcopy(scorer), deepcopy(scorer) + scorer_pos_label._kwargs["pos_label"] = "cancer" + scorer_neg_label._kwargs["pos_label"] = "not cancer" + multi_scoring = { + "scorer_str": scoring, + "scorer_pos": scorer_pos_label, + "scorer_neg": scorer_neg_label, + } + + classifier.set_params( + scoring=multi_scoring, refit="scorer_str", + ) + classifier.fit(X, y) + assert_allclose( + classifier.cv_results_["mean_test_scorer_str"], + classifier.cv_results_["mean_test_scorer_pos"] + ) + assert_allclose( + classifier.cv_results_["mean_test_scorer_str"], + classifier.cv_results_["mean_test_scorer_neg"] + ) + else: + with pytest.raises(ValueError): + # it should raise an error by default + classifier.fit(X, y) + + # passing pos_label should always solve the issue and should be equivalent + # to encode the label with {0, 1}. + + # we should control our cv indices since y will be different leading + # to different cv split + indices = np.arange(y.shape[0]) + cv = [ + (indices[:indices.size // 2], indices[indices.size // 2:]), + (indices[indices.size // 2:], indices[:indices.size // 2]), + ] + classifier.set_params(cv=cv, scoring=scoring, refit=True) + + y_encoded = (y == "cancer").astype(int) + classifier.fit(X, y_encoded) + mean_test_score_y_encoded = classifier.cv_results_["mean_test_score"] + + scorer = get_scorer(scoring) + scorer._kwargs["pos_label"] = "cancer" + classifier.set_params(scoring=scorer) + classifier.fit(X, y) + mean_test_score_pos_label = classifier.cv_results_["mean_test_score"] + + assert_allclose(mean_test_score_pos_label, mean_test_score_y_encoded)