
ENH add a parameter pos_label in roc_auc_score #17704


Closed · wants to merge 18 commits
3 changes: 3 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -240,6 +240,9 @@ Changelog
``metric='seuclidean'`` and ``X`` is not type ``np.float64``.
:pr:`15730` by :user:`Forrest Koch <ForrestCKoch>`.

- |Enhancement| Add `pos_label` parameter to :func:`roc_auc_score`.
:pr:`17594` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.model_selection`
..............................

23 changes: 16 additions & 7 deletions sklearn/metrics/_ranking.py
@@ -218,14 +218,16 @@ def _binary_uninterpolated_average_precision(
average, sample_weight=sample_weight)


def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None,
pos_label=None):
"""Binary roc auc score"""
if len(np.unique(y_true)) != 2:
raise ValueError("Only one class present in y_true. ROC AUC score "
"is not defined in that case.")

fpr, tpr, _ = roc_curve(y_true, y_score,
sample_weight=sample_weight)
fpr, tpr, _ = roc_curve(
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label,
)
if max_fpr is None or max_fpr == 1:
return auc(fpr, tpr)
if max_fpr <= 0 or max_fpr > 1:
@@ -248,7 +250,8 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):

@_deprecate_positional_args
def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
max_fpr=None, multi_class="raise", labels=None):
max_fpr=None, multi_class="raise", labels=None,
pos_label=None):
"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
from prediction scores.

@@ -327,6 +330,13 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
If ``None``, the numerical or lexicographical order of the labels in
``y_true`` is used.

pos_label : int or str, default=None
The label of the positive class in the binary case. When
`pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, `pos_label` is
set to 1, otherwise an error will be raised.

.. versionadded:: 0.24

Returns
-------
auc : float
@@ -388,10 +398,9 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
return _multiclass_roc_auc_score(y_true, y_score, labels,
multi_class, average, sample_weight)
elif y_type == "binary":
labels = np.unique(y_true)
y_true = label_binarize(y_true, classes=labels)[:, 0]
return _average_binary_score(partial(_binary_roc_auc_score,
max_fpr=max_fpr),
max_fpr=max_fpr,
pos_label=pos_label),
y_true, y_score, average,
sample_weight=sample_weight)
else: # multilabel-indicator
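For context, a hedged usage sketch of the parameter this file introduces. Since the PR was closed rather than merged, `pos_label` is not a `roc_auc_score` parameter in released scikit-learn; the snippet assumes this branch.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array(["cancer", "not cancer", "cancer", "not cancer"])
y_score = np.array([0.9, 0.2, 0.6, 0.4])  # higher = more likely "cancer"

# with string labels there is no {0, 1} / {-1, 1} convention to fall back
# on, so the positive class must be named explicitly
auc = roc_auc_score(y_true, y_score, pos_label="cancer")
print(auc)  # 1.0: every "cancer" sample outscores every "not cancer" one
```
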
27 changes: 23 additions & 4 deletions sklearn/metrics/_scorer.py
@@ -18,9 +18,10 @@
# Arnaud Joly <arnaud.v.joly@gmail.com>
# License: Simplified BSD

from collections import Counter
from collections.abc import Iterable
from copy import deepcopy
from functools import partial
from collections import Counter

import numpy as np

@@ -239,7 +240,13 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
y_pred = method_caller(clf, "predict_proba", X)
if y_type == "binary":
if y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
if "pos_label" in self._kwargs:
col_idx = np.flatnonzero(
clf.classes_ == self._kwargs["pos_label"]
)[0]
else:
col_idx = 1
y_pred = y_pred[:, col_idx]
elif y_pred.shape[1] == 1: # not multiclass
raise ValueError('got predict_proba of shape {},'
' but need classifier with two'
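
A standalone sketch of the column lookup this hunk adds. `classes_` is lexicographically sorted, so with string labels the positive class is not guaranteed to sit in column 1 of `predict_proba`; the data and names below are purely illustrative.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0], [0.5], [2.5], [3.0]])
y = np.array(["cancer", "cancer", "not cancer", "not cancer"])
clf = LogisticRegression().fit(X, y)

# classes_ is ["cancer", "not cancer"], so the positive class "cancer"
# lives in column 0, not the default column 1
pos_label = "cancer"  # stand-in for self._kwargs["pos_label"]
col_idx = np.flatnonzero(clf.classes_ == pos_label)[0]
y_score = clf.predict_proba(X)[:, col_idx]
```
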
@@ -307,7 +314,19 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):

if y_type == "binary":
if y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
if (
self._score_func.__name__ == "roc_auc_score"
and "pos_label" not in self._kwargs
Comment on lines +318 to +319 (Member):

we add the pos_label (with the mutable aspects discussed in the other PR)

I am okay with this, with a symmetric property on _BaseScorer that defaults to False. This way, we can be generic and not depend on the name of the score function.

):
self._kwargs["pos_label"] = clf.classes_[1]

if "pos_label" in self._kwargs:
col_idx = np.flatnonzero(
clf.classes_ == self._kwargs["pos_label"]
)[0]
else:
col_idx = 1
y_pred = y_pred[:, col_idx]
else:
raise ValueError('got predict_proba of shape {},'
' but need classifier with two'
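
A hedged sketch of the reviewer's alternative suggested above: rather than matching on `self._score_func.__name__`, give `_BaseScorer` a `symmetric` attribute that defaults to False. Everything below is hypothetical, not part of this diff.

```python
class _BaseScorerSketch:
    """Hypothetical variant of _BaseScorer carrying a `symmetric` flag."""

    # most metrics change value when the positive class is swapped
    symmetric = False

    def __init__(self, score_func, kwargs):
        self._score_func = score_func
        self._kwargs = dict(kwargs)

    def _resolve_pos_label(self, clf):
        # only a metric invariant to swapping the classes (e.g. ROC AUC,
        # provided the matching probability column is used) may default
        # to classes_[1]; others must receive pos_label explicitly
        if self.symmetric and "pos_label" not in self._kwargs:
            self._kwargs["pos_label"] = clf.classes_[1]
        return self._kwargs.get("pos_label")


class _RocAucScorerSketch(_BaseScorerSketch):
    symmetric = True
```
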
@@ -352,7 +371,7 @@ def get_scorer(scoring):
'to get valid options.' % scoring)
else:
scorer = scoring
return scorer
return deepcopy(scorer)


def _passthrough_scorer(estimator, *args, **kwargs):
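Why `get_scorer` now returns a copy: the `_ThresholdScorer` change above writes `pos_label` into `self._kwargs`, and the entries of the SCORERS registry are module-level singletons, so returning the shared instance would leak that mutation into unrelated uses. A minimal illustration of the aliasing problem, with plain dicts standing in for scorer objects:

```python
from copy import deepcopy

registry = {"roc_auc": {"kwargs": {}}}  # stand-in for the SCORERS dict

shared = registry["roc_auc"]             # old behaviour: shared instance
shared["kwargs"]["pos_label"] = "cancer"
# the mutation is now visible to every later lookup
assert registry["roc_auc"]["kwargs"] == {"pos_label": "cancer"}

registry = {"roc_auc": {"kwargs": {}}}    # reset
isolated = deepcopy(registry["roc_auc"])  # new behaviour: private copy
isolated["kwargs"]["pos_label"] = "cancer"
# the registry entry stays pristine
assert registry["roc_auc"]["kwargs"] == {}
```
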
11 changes: 11 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -321,6 +321,17 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
# Metrics with a "pos_label" argument
METRICS_WITH_POS_LABEL = {
"roc_curve",

"roc_auc_score",
"weighted_roc_auc",
"samples_roc_auc",
"micro_roc_auc",
"ovr_roc_auc",
"weighted_ovr_roc_auc",
"ovo_roc_auc",
"weighted_ovo_roc_auc",
"partial_roc_auc",

"precision_recall_curve",

"brier_score_loss",
43 changes: 42 additions & 1 deletion sklearn/metrics/tests/test_ranking.py
@@ -7,9 +7,13 @@
from sklearn import datasets
from sklearn import svm

from sklearn.utils.extmath import softmax
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.random_projection import _sparse_random_matrix
from sklearn.utils import shuffle
from sklearn.utils.extmath import softmax
from sklearn.utils.validation import check_array, check_consistent_length
from sklearn.utils.validation import check_random_state

@@ -1469,3 +1473,40 @@ def test_partial_roc_auc_score():
assert_almost_equal(
roc_auc_score(y_true, y_pred, max_fpr=max_fpr),
_partial_roc_auc_score(y_true, y_pred, max_fpr))


@pytest.mark.parametrize(
"decision_method", ["predict_proba", "decision_function"]
)
def test_roc_auc_score_pos_label(decision_method):
X, y = load_breast_cancer(return_X_y=True)
# create a highly imbalanced classification task
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=0,
)

classifier = LogisticRegression()
classifier.fit(X_train, y_train)

# sanity check: the positive class is classes_[0], so relying on the
# default assumption that the positive class is classes_[1] would be wrong
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
pos_label = "cancer"

y_pred = getattr(classifier, decision_method)(X_test)
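# predict_proba returns one column per class, with the positive class
# "cancer" in column 0; decision_function is oriented towards classes_[1],
# so it is negated to obtain a score for the positive class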
y_pred = y_pred[:, 0] if y_pred.ndim == 2 else -y_pred

fpr, tpr, _ = roc_curve(y_test, y_pred, pos_label=pos_label)
roc_auc = roc_auc_score(y_test, y_pred, pos_label=pos_label)

assert roc_auc == pytest.approx(np.trapz(tpr, fpr))
100 changes: 98 additions & 2 deletions sklearn/metrics/tests/test_score_objects.py
@@ -3,6 +3,7 @@
import shutil
import os
import numbers
from copy import deepcopy
from unittest.mock import Mock
from functools import partial

@@ -11,6 +12,7 @@
import joblib

from numpy.testing import assert_allclose
from sklearn.utils import shuffle
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import ignore_warnings
@@ -32,10 +34,12 @@
from sklearn.cluster import KMeans
from sklearn.linear_model import Ridge, LogisticRegression, Perceptron
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.datasets import load_diabetes
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import make_blobs
from sklearn.datasets import make_classification, make_regression
from sklearn.datasets import make_classification
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import load_diabetes
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
@@ -747,3 +751,95 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
msg = "'Perceptron' object has no attribute 'predict_proba'"
with pytest.raises(AttributeError, match=msg):
scorer(lr, X, y)


@pytest.mark.parametrize(
"scoring, is_symmetric",
[
("roc_auc", True),
("jaccard", False),
("f1", False),
("average_precision", False),
("precision", False),
("recall", False),
("neg_brier_score", True),
],
)
def test_scorer_pos_label_grid_search(scoring, is_symmetric):
# Check the behaviour for the scorer which requires a `pos_label` with
# binary target. Non-regression test for:
# https://github.com/scikit-learn/scikit-learn/pull/17572
X, y = load_breast_cancer(return_X_y=True)
# create a highly imbalanced classification task
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
)

param_grid = {"max_depth": [1, 3, 5]}
classifier = GridSearchCV(
DecisionTreeClassifier(random_state=0),
param_grid=param_grid,
scoring=scoring,
cv=2,
)

if is_symmetric:
# we will expand to compute for several scorer with different pos_label
# which should all give the same results
scorer = get_scorer(scoring)
scorer_pos_label, scorer_neg_label = deepcopy(scorer), deepcopy(scorer)
scorer_pos_label._kwargs["pos_label"] = "cancer"
scorer_neg_label._kwargs["pos_label"] = "not cancer"
multi_scoring = {
"scorer_str": scoring,
"scorer_pos": scorer_pos_label,
"scorer_neg": scorer_neg_label,
}

classifier.set_params(
scoring=multi_scoring, refit="scorer_str",
)
classifier.fit(X, y)
assert_allclose(
classifier.cv_results_["mean_test_scorer_str"],
classifier.cv_results_["mean_test_scorer_pos"]
)
assert_allclose(
classifier.cv_results_["mean_test_scorer_str"],
classifier.cv_results_["mean_test_scorer_neg"]
)
else:
with pytest.raises(ValueError):
# it should raise an error by default
classifier.fit(X, y)

# passing pos_label should always solve the issue and should be equivalent
# to encoding the labels as {0, 1}

# we need to control the cv indices since the two fits use different y,
# which would otherwise lead to different cv splits
indices = np.arange(y.shape[0])
cv = [
(indices[:indices.size // 2], indices[indices.size // 2:]),
(indices[indices.size // 2:], indices[:indices.size // 2]),
]
classifier.set_params(cv=cv, scoring=scoring, refit=True)

y_encoded = (y == "cancer").astype(int)
classifier.fit(X, y_encoded)
mean_test_score_y_encoded = classifier.cv_results_["mean_test_score"]

scorer = get_scorer(scoring)
scorer._kwargs["pos_label"] = "cancer"
classifier.set_params(scoring=scorer)
classifier.fit(X, y)
mean_test_score_pos_label = classifier.cv_results_["mean_test_score"]

assert_allclose(mean_test_score_pos_label, mean_test_score_y_encoded)
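
For completeness, a hedged sketch of the user-facing workflow this test exercises, again assuming this PR's branch (in released scikit-learn, `roc_auc_score` has no `pos_label` parameter): `pos_label` is forwarded through `make_scorer` into the scorer's `_kwargs`.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import make_scorer, roc_auc_score
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)

# needs_threshold=True routes the scorer through decision_function when
# available, falling back to predict_proba for classifiers like this tree;
# pos_label lands in the scorer's _kwargs and selects the matching column
scorer = make_scorer(roc_auc_score, needs_threshold=True, pos_label="cancer")

search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"max_depth": [1, 3, 5]},
    scoring=scorer,
    cv=2,
).fit(X, y)
print(search.best_params_)
```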