
FIX make it possible to specify the positive label in roc_auc_score #18107
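For context, a minimal sketch of the usage this change is meant to enable, assuming the `pos_label` parameter added to `roc_auc_score` in the diff below; the dataset, the label mapping and the estimator settings are illustrative only.

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
# string labels, so neither the {0, 1} nor the {-1, 1} convention applies
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression(max_iter=10_000).fit(X_train, y_train)
# probability column of the class we want to treat as positive
col = np.flatnonzero(clf.classes_ == "cancer")[0]
proba_cancer = clf.predict_proba(X_test)[:, col]
# with this change the positive class is stated explicitly instead of being
# inferred as np.unique(y_test)[1] ("not cancer" here), which would be wrong
print(roc_auc_score(y_test, proba_cancer, pos_label="cancer"))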

Closed · wants to merge 7 commits
17 changes: 10 additions & 7 deletions sklearn/metrics/_ranking.py
@@ -218,14 +218,17 @@ def _binary_uninterpolated_average_precision(
average, sample_weight=sample_weight)


def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
def _binary_roc_auc_score(
y_true, y_score, sample_weight=None, max_fpr=None, pos_label=None,
):
"""Binary roc auc score."""
if len(np.unique(y_true)) != 2:
raise ValueError("Only one class present in y_true. ROC AUC score "
"is not defined in that case.")

fpr, tpr, _ = roc_curve(y_true, y_score,
sample_weight=sample_weight)
fpr, tpr, _ = roc_curve(
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label
)
if max_fpr is None or max_fpr == 1:
return auc(fpr, tpr)
if max_fpr <= 0 or max_fpr > 1:
@@ -248,7 +251,8 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):

@_deprecate_positional_args
def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
max_fpr=None, multi_class="raise", labels=None):
max_fpr=None, multi_class="raise", labels=None,
pos_label=None):
"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
from prediction scores.

@@ -388,10 +392,9 @@ def roc_auc_score(y_true, y_score, *, average="macro", sample_weight=None,
return _multiclass_roc_auc_score(y_true, y_score, labels,
multi_class, average, sample_weight)
elif y_type == "binary":
labels = np.unique(y_true)
y_true = label_binarize(y_true, classes=labels)[:, 0]
return _average_binary_score(partial(_binary_roc_auc_score,
max_fpr=max_fpr),
max_fpr=max_fpr,
pos_label=pos_label),
y_true, y_score, average,
sample_weight=sample_weight)
else: # multilabel-indicator
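To illustrate the effect of the change above (forwarding pos_label to roc_curve instead of binarizing y_true against np.unique), a small self-contained check with string labels; the scores are made up for the example and roc_curve's existing pos_label parameter is all that is used.

import numpy as np
from sklearn.metrics import auc, roc_curve

y_true = np.array(["cancer", "not cancer", "cancer", "not cancer"], dtype=object)
y_score = np.array([0.9, 0.2, 0.6, 0.4])  # larger means "more likely cancer"

# roc_curve already accepts pos_label; the diff above forwards it from
# roc_auc_score instead of hard-coding the positive class
fpr, tpr, _ = roc_curve(y_true, y_score, pos_label="cancer")
print(auc(fpr, tpr))  # 1.0: every "cancer" sample outranks every "not cancer"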
101 changes: 72 additions & 29 deletions sklearn/metrics/_scorer.py
@@ -18,9 +18,11 @@
# Arnaud Joly <arnaud.v.joly@gmail.com>
# License: Simplified BSD

from collections import Counter
from collections.abc import Iterable
from copy import deepcopy
from functools import partial
from collections import Counter
from inspect import signature

import numpy as np

@@ -122,10 +124,11 @@ def _use_cache(self, estimator):


class _BaseScorer:
def __init__(self, score_func, sign, kwargs):
def __init__(self, score_func, sign, is_symmetric, kwargs):
self._kwargs = kwargs
self._score_func = score_func
self._sign = sign
self._is_symmetric = is_symmetric

def __repr__(self):
kwargs_string = "".join([", %s=%s" % (str(k), str(v))
@@ -237,18 +240,24 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):

y_type = type_of_target(y)
y_pred = method_caller(clf, "predict_proba", X)

if y_type == "binary":
if y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
col_idx = np.flatnonzero(
clf.classes_ == self._kwargs.get(
"pos_label", clf.classes_[1]
)
)[0]
y_pred = y_pred[:, col_idx]
elif y_pred.shape[1] == 1: # not multiclass
raise ValueError('got predict_proba of shape {},'
' but need classifier with two'
' classes for {} scoring'.format(
y_pred.shape, self._score_func.__name__))
if sample_weight is not None:
return self._sign * self._score_func(y, y_pred,
sample_weight=sample_weight,
**self._kwargs)
return self._sign * self._score_func(
y, y_pred, sample_weight=sample_weight, **self._kwargs
)
else:
return self._sign * self._score_func(y, y_pred, **self._kwargs)
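A standalone illustration of the column-selection logic used in _ProbaScorer above (not part of the diff): the predict_proba column is chosen to match the requested pos_label, with classes_[1] as the fallback. The toy data and variable names are made up.

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0], [0.2], [0.8], [1.0]])
y = np.array(["neg", "neg", "pos", "pos"], dtype=object)
clf = LogisticRegression().fit(X, y)

kwargs = {"pos_label": "pos"}  # what the scorer would receive via make_scorer
col_idx = np.flatnonzero(
    clf.classes_ == kwargs.get("pos_label", clf.classes_[1])
)[0]
proba_pos = clf.predict_proba(X)[:, col_idx]
print(proba_pos)  # probabilities of the requested positive class, "pos"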

@@ -292,22 +301,42 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
if y_type not in ("binary", "multilabel-indicator"):
raise ValueError("{0} format is not supported".format(y_type))

kwargs = deepcopy(self._kwargs)
params_score_func = signature(self._score_func).parameters
if (
self._is_symmetric
and "pos_label" in params_score_func
and "pos_label" not in kwargs
):
kwargs["pos_label"] = clf.classes_[1]

if is_regressor(clf):
y_pred = method_caller(clf, "predict", X)
else:
try:
y_pred = method_caller(clf, "decision_function", X)

# For multi-output multi-class estimator
if isinstance(y_pred, list):
# For multi-output multi-class estimator
y_pred = np.vstack([p for p in y_pred]).T
elif (
y_type == "binary"
and "pos_label" in kwargs
and kwargs["pos_label"] == clf.classes_[0]
):
y_pred *= -1

except (NotImplementedError, AttributeError):
y_pred = method_caller(clf, "predict_proba", X)

if y_type == "binary":
if y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
col_idx = np.flatnonzero(
clf.classes_ == kwargs.get(
"pos_label", clf.classes_[1]
)
)[0]
y_pred = y_pred[:, col_idx]
else:
raise ValueError('got predict_proba of shape {},'
' but need classifier with two'
@@ -318,11 +347,11 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
y_pred = np.vstack([p[:, -1] for p in y_pred]).T

if sample_weight is not None:
return self._sign * self._score_func(y, y_pred,
sample_weight=sample_weight,
**self._kwargs)
return self._sign * self._score_func(
y, y_pred, sample_weight=sample_weight, **kwargs
)
else:
return self._sign * self._score_func(y, y_pred, **self._kwargs)
return self._sign * self._score_func(y, y_pred, **kwargs)

def _factory_args(self):
return ", needs_threshold=True"
@@ -494,8 +523,15 @@ def _check_multimetric_scoring(estimator, scoring):


@_deprecate_positional_args
def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
needs_threshold=False, **kwargs):
def make_scorer(
score_func,
*,
greater_is_better=True,
needs_proba=False,
needs_threshold=False,
is_symmetric=False,
**kwargs,
):
"""Make a scorer from a performance metric or loss function.

This factory function wraps scoring functions for use in GridSearchCV
@@ -575,7 +611,7 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
cls = _ThresholdScorer
else:
cls = _PredictScorer
return cls(score_func, sign, kwargs)
return cls(score_func, sign, is_symmetric, kwargs)


# Standard regression scores
@@ -610,20 +646,27 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)

# Score functions that need decision values
roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
needs_threshold=True)
average_precision_scorer = make_scorer(average_precision_score,
needs_threshold=True)
roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_proba=True,
multi_class='ovo')
roc_auc_ovo_weighted_scorer = make_scorer(roc_auc_score, needs_proba=True,
multi_class='ovo',
average='weighted')
roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_proba=True,
multi_class='ovr')
roc_auc_ovr_weighted_scorer = make_scorer(roc_auc_score, needs_proba=True,
multi_class='ovr',
average='weighted')
roc_auc_scorer = make_scorer(
roc_auc_score,
greater_is_better=True,
needs_threshold=True,
is_symmetric=True,
)
average_precision_scorer = make_scorer(
average_precision_score, needs_threshold=True
)
roc_auc_ovo_scorer = make_scorer(
roc_auc_score, needs_proba=True, multi_class="ovo"
)
roc_auc_ovo_weighted_scorer = make_scorer(
roc_auc_score, needs_proba=True, multi_class="ovo", average="weighted"
)
roc_auc_ovr_scorer = make_scorer(
roc_auc_score, needs_proba=True, multi_class="ovr"
)
roc_auc_ovr_weighted_scorer = make_scorer(
roc_auc_score, needs_proba=True, multi_class="ovr", average="weighted"
)

# Score function for probabilistic classification
neg_log_loss_scorer = make_scorer(log_loss, greater_is_better=False,
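Assuming this PR's is_symmetric flag and the pos_label handling above, the updated factory could be used roughly as follows, mirroring what the new tests exercise; the dataset and label mapping are only illustrative.

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, make_scorer, roc_auc_score

X, y = load_breast_cancer(return_X_y=True)
y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object)
clf = LogisticRegression(max_iter=10_000).fit(X, y)

# metric that needs an explicit positive label: pass it through make_scorer
ap_scorer = make_scorer(
    average_precision_score, needs_threshold=True, pos_label="cancer"
)
# symmetric metric: classes_[1] is filled in automatically when no pos_label is given
auc_scorer = make_scorer(roc_auc_score, needs_threshold=True, is_symmetric=True)

print(ap_scorer(clf, X, y), auc_scorer(clf, X, y))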
51 changes: 50 additions & 1 deletion sklearn/metrics/tests/test_ranking.py
@@ -7,9 +7,14 @@
from sklearn import datasets
from sklearn import svm

from sklearn.utils.extmath import softmax
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.random_projection import _sparse_random_matrix
from sklearn.utils import shuffle
from sklearn.utils.extmath import softmax
from sklearn.utils.validation import check_array, check_consistent_length
from sklearn.utils.validation import check_random_state

@@ -1469,3 +1474,47 @@ def test_partial_roc_auc_score():
assert_almost_equal(
roc_auc_score(y_true, y_pred, max_fpr=max_fpr),
_partial_roc_auc_score(y_true, y_pred, max_fpr))


@pytest.mark.parametrize(
"decision_method", ["predict_proba", "decision_function"]
)
def test_roc_auc_score_pos_label(decision_method):
X, y = load_breast_cancer(return_X_y=True)
# create a highly imbalanced classification task
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=0,
)

classifier = LogisticRegression()
classifier.fit(X_train, y_train)

# sanity check: the positive class "cancer" is classes_[0], so relying on the
# usual default of classes_[1] as the positive class would give the wrong score
assert classifier.classes_.tolist() == ["cancer", "not cancer"]
pos_label = "cancer"

# column 0 of predict_proba and the negated decision_function both score the
# "cancer" class (classes_[0])
y_pred = getattr(classifier, decision_method)(X_test)
y_pred = y_pred[:, 0] if y_pred.ndim == 2 else -y_pred

fpr, tpr, _ = roc_curve(y_test, y_pred, pos_label=pos_label)
roc_auc = roc_auc_score(y_test, y_pred, pos_label=pos_label)

assert roc_auc == pytest.approx(np.trapz(tpr, fpr))

param_grid = {"C": [0.1, 1]}
grid_search = GridSearchCV(
classifier, param_grid=param_grid, scoring="roc_auc"
)
grid_search.fit(X, y)
print(grid_search.cv_results_)
79 changes: 79 additions & 0 deletions sklearn/metrics/tests/test_score_objects.py
@@ -5,6 +5,7 @@
import numbers
from unittest.mock import Mock
from functools import partial
from _pytest.python_api import approx

import numpy as np
import pytest
@@ -35,10 +36,12 @@
from sklearn.datasets import make_blobs
from sklearn.datasets import make_classification, make_regression
from sklearn.datasets import make_multilabel_classification
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.utils import shuffle


REGRESSION_SCORERS = ['explained_variance', 'r2',
@@ -747,3 +750,79 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
msg = "'Perceptron' object has no attribute 'predict_proba'"
with pytest.raises(AttributeError, match=msg):
scorer(lr, X, y)


def _make_imbalanced_string_dataset():
X, y = load_breast_cancer(return_X_y=True)
# create a highly imbalanced classification task
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=0,
)
return X_train, X_test, y_train, y_test


def test_average_precision_pos_label():
from sklearn.metrics import average_precision_score
X_train, X_test, y_train, y_test = _make_imbalanced_string_dataset()

classifier = LogisticRegression().fit(X_train, y_train)
y_proba = classifier.predict_proba(X_test)
y_decision_function = classifier.decision_function(X_test)

pos_label = "cancer"
y_proba = y_proba[:, 0]
y_decision_function *= -1
Review comment (Member):
Maybe add a line such as:

assert classifier.classes_[0] == pos_label

to make the test easier to follow.


ap_proba = average_precision_score(y_test, y_proba, pos_label=pos_label)
ap_decision_function = average_precision_score(
y_test, y_decision_function, pos_label=pos_label
)
assert ap_proba == pytest.approx(ap_decision_function)

average_precision_scorer = make_scorer(
average_precision_score, needs_threshold=True,
)
with pytest.raises(ValueError):
average_precision_scorer(classifier, X_test, y_test)
Review comment (@ogrisel, Member, Aug 6, 2020):
Please also check the message to make the test easier to follow.


average_precision_scorer = make_scorer(
average_precision_score, needs_threshold=True, pos_label=pos_label
)
ap_scorer = average_precision_scorer(classifier, X_test, y_test)

assert ap_scorer == pytest.approx(ap_proba)


def test_roc_auc_pos_label():
from sklearn.metrics import roc_auc_score
X_train, X_test, y_train, y_test = _make_imbalanced_string_dataset()

classifier = LogisticRegression().fit(X_train, y_train)
y_proba = classifier.predict_proba(X_test)
y_decision_function = classifier.decision_function(X_test)

pos_label = "cancer"
y_proba = y_proba[:, 0]
y_decision_function *= -1

roc_auc_proba = roc_auc_score(y_test, y_proba, pos_label=pos_label)
roc_auc_decision_function = roc_auc_score(
y_test, y_decision_function, pos_label=pos_label
)
assert roc_auc_proba == pytest.approx(roc_auc_decision_function)

roc_auc_scorer = make_scorer(
roc_auc_score, needs_threshold=True, is_symmetric=True
)
roc_auc_from_scorer = roc_auc_scorer(classifier, X_test, y_test)
assert roc_auc_from_scorer == pytest.approx(roc_auc_proba)