FIX select the probability estimates or transform the decision values when pos_label is provided #18114

Merged
merged 32 commits into from
Oct 9, 2020
Changes from all commits
32 commits
a565589
TST wip
glemaitre Aug 7, 2020
6c0db49
TST wip
glemaitre Aug 7, 2020
c41e999
TST wip
glemaitre Aug 7, 2020
f53f833
TST PEP8 + comments
glemaitre Aug 7, 2020
dd4e9fe
TST force to use predict_proba as well
glemaitre Aug 7, 2020
e32cfa7
DOC add whats + PEP8
glemaitre Aug 7, 2020
07915e9
TST add some tolerance since the average of squared in diff ordered
glemaitre Aug 7, 2020
fc1c422
STY add better error message and refactor code
glemaitre Aug 10, 2020
aa5cd16
fix
glemaitre Aug 10, 2020
a669ecf
fix
glemaitre Aug 10, 2020
a477e7b
Update sklearn/metrics/tests/test_score_objects.py
glemaitre Aug 11, 2020
09b47bb
add test for PredictScorer
glemaitre Aug 11, 2020
42e7f00
apply olivier suggestions
glemaitre Aug 18, 2020
e9d7873
use list
glemaitre Aug 18, 2020
7025768
fix
glemaitre Aug 19, 2020
536753f
fix
glemaitre Aug 19, 2020
6a12a1f
PEP8
glemaitre Aug 19, 2020
38fc931
Merge branch 'master' into is/scorer_pos_label
glemaitre Aug 24, 2020
40b3d0d
itter
glemaitre Sep 2, 2020
c2b97b5
iter
glemaitre Sep 3, 2020
eb18c83
iter
glemaitre Sep 3, 2020
e9f608d
iter
glemaitre Sep 3, 2020
e612795
Merge remote-tracking branch 'origin/master' into is/scorer_pos_label
glemaitre Oct 5, 2020
fc86836
Merge remote-tracking branch 'glemaitre/is/scorer_pos_label' into is/…
glemaitre Oct 5, 2020
59a3e8f
iter
glemaitre Oct 5, 2020
bbe0e93
only select probab in the binary case
glemaitre Oct 5, 2020
98a745c
add small comment
glemaitre Oct 5, 2020
19423af
add assert suggested by ogrisel
glemaitre Oct 5, 2020
ab30367
iter
glemaitre Oct 5, 2020
89c04f0
avoid warning
glemaitre Oct 5, 2020
b3b0bfd
Update sklearn/metrics/tests/test_score_objects.py
glemaitre Oct 8, 2020
a8d2c29
iter
glemaitre Oct 8, 2020
7 changes: 7 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -410,6 +410,13 @@ Changelog
``labels`` parameter.
:pr:`17935` by :user:`Cary Goltermann <Ultramann>`.

- |Fix| Fix scorers that accept a `pos_label` parameter and compute their metrics
from values returned by `decision_function` or `predict_proba`. Previously,
they would return erroneous values when `pos_label` did not correspond to
`classifier.classes_[1]`. This is especially important when training
classifiers directly on string-labeled target classes.
:pr:`18114` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.model_selection`
..............................

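To make the changelog entry above concrete, here is a minimal sketch (not part of the PR itself; dataset and variable names are illustrative) of the situation it describes: a binary classifier trained on string labels where the positive class ends up as `classes_[0]`, scored with a `pos_label`-aware scorer.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score, make_scorer

X, y_num = make_classification(n_classes=2, weights=[0.9, 0.1], random_state=0)
# string labels: "cancer" sorts before "not cancer", so it becomes classes_[0]
y = np.array(["cancer" if c == 1 else "not cancer" for c in y_num], dtype=object)
clf = LogisticRegression().fit(X, y)

# The scorer below relies on decision_function/predict_proba internally.
# Before this fix it implicitly scored classes_[1] ("not cancer") even though
# pos_label="cancer" was requested, which produced erroneous values.
ap_scorer = make_scorer(average_precision_score, needs_threshold=True,
                        pos_label="cancer")
score = ap_scorer(clf, X, y)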
70 changes: 53 additions & 17 deletions sklearn/metrics/_scorer.py
@@ -127,6 +127,43 @@ def __init__(self, score_func, sign, kwargs):
self._score_func = score_func
self._sign = sign

@staticmethod
def _check_pos_label(pos_label, classes):
if pos_label not in list(classes):
raise ValueError(
f"pos_label={pos_label} is not a valid label: {classes}"
)

def _select_proba_binary(self, y_pred, classes):
"""Select the column of the positive label in `y_pred` when
probabilities are provided.

Parameters
----------
y_pred : ndarray of shape (n_samples, n_classes)
The prediction given by `predict_proba`.

classes : ndarray of shape (n_classes,)
The class labels for the estimator.

Returns
-------
y_pred : ndarray of shape (n_samples,)
Probability predictions of the positive class.
"""
if y_pred.shape[1] == 2:
pos_label = self._kwargs.get("pos_label", classes[1])
self._check_pos_label(pos_label, classes)
col_idx = np.flatnonzero(classes == pos_label)[0]
return y_pred[:, col_idx]

err_msg = (
f"Got predict_proba of shape {y_pred.shape}, but need "
f"classifier with two classes for {self._score_func.__name__} "
f"scoring"
)
raise ValueError(err_msg)

def __repr__(self):
kwargs_string = "".join([", %s=%s" % (str(k), str(v))
for k, v in self._kwargs.items()])
@@ -237,14 +274,11 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):

y_type = type_of_target(y)
y_pred = method_caller(clf, "predict_proba", X)
if y_type == "binary":
if y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
elif y_pred.shape[1] == 1: # not multiclass
raise ValueError('got predict_proba of shape {},'
' but need classifier with two'
' classes for {} scoring'.format(
y_pred.shape, self._score_func.__name__))
if y_type == "binary" and y_pred.shape[1] <= 2:
# `y_type` could be equal to "binary" even in a multi-class
# problem (when only 2 classes are given to `y_true` during scoring).
# Thus, we need to check the shape of `y_pred`.
y_pred = self._select_proba_binary(y_pred, clf.classes_)
if sample_weight is not None:
return self._sign * self._score_func(y, y_pred,
sample_weight=sample_weight,
@@ -298,22 +332,24 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
try:
y_pred = method_caller(clf, "decision_function", X)

# For multi-output multi-class estimator
if isinstance(y_pred, list):
# For multi-output multi-class estimator
y_pred = np.vstack([p for p in y_pred]).T
elif y_type == "binary" and "pos_label" in self._kwargs:
self._check_pos_label(
self._kwargs["pos_label"], clf.classes_
)
if self._kwargs["pos_label"] == clf.classes_[0]:
# The implicit positive class of the binary classifier
# does not match `pos_label`: we need to invert the
# predictions
y_pred *= -1

except (NotImplementedError, AttributeError):
y_pred = method_caller(clf, "predict_proba", X)

if y_type == "binary":
if y_pred.shape[1] == 2:
y_pred = y_pred[:, 1]
else:
raise ValueError('got predict_proba of shape {},'
' but need classifier with two'
' classes for {} scoring'.format(
y_pred.shape,
self._score_func.__name__))
y_pred = self._select_proba_binary(y_pred, clf.classes_)
elif isinstance(y_pred, list):
y_pred = np.vstack([p[:, -1] for p in y_pred]).T

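As a standalone illustration (not the scorer code itself) of what `_select_proba_binary` now does, the snippet below selects the `predict_proba` column matching `pos_label` instead of hard-coding column 1; the arrays are made-up examples.

import numpy as np

classes = np.array(["cancer", "not cancer"])       # would be clf.classes_
y_proba = np.array([[0.8, 0.2],
                    [0.1, 0.9]])                   # output of predict_proba
pos_label = "cancer"

# locate the column of the requested positive class rather than assuming
# that it is always classes_[1]
col_idx = np.flatnonzero(classes == pos_label)[0]
y_pred = y_proba[:, col_idx]                       # array([0.8, 0.1])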
226 changes: 223 additions & 3 deletions sklearn/metrics/tests/test_score_objects.py
@@ -1,3 +1,4 @@
from copy import deepcopy
import pickle
import tempfile
import shutil
@@ -16,9 +17,18 @@
from sklearn.utils._testing import ignore_warnings

from sklearn.base import BaseEstimator
from sklearn.metrics import (f1_score, r2_score, roc_auc_score, fbeta_score,
log_loss, precision_score, recall_score,
jaccard_score)
from sklearn.metrics import (
average_precision_score,
brier_score_loss,
f1_score,
fbeta_score,
jaccard_score,
log_loss,
precision_score,
r2_score,
recall_score,
roc_auc_score,
)
from sklearn.metrics import cluster as cluster_module
from sklearn.metrics import check_scoring
from sklearn.metrics._scorer import (_PredictScorer, _passthrough_scorer,
@@ -618,6 +628,8 @@ def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count,
mock_est.predict = predict_func
mock_est.predict_proba = predict_proba_func
mock_est.decision_function = decision_function_func
# add the classes that would be found during fit
mock_est.classes_ = np.array([0, 1])

scorer_dict = _check_multimetric_scoring(LogisticRegression(), scorers)
multi_scorer = _MultimetricScorer(**scorer_dict)
@@ -747,3 +759,211 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
msg = "'Perceptron' object has no attribute 'predict_proba'"
with pytest.raises(AttributeError, match=msg):
scorer(lr, X, y)


@pytest.fixture
def string_labeled_classification_problem():
"""Train a classifier on binary problem with string target.

The classifier is trained on a binary classification problem where the
minority class of interest has a string label that is intentionally not the
greatest class label in lexicographic order. In this case, "cancer"
is the positive label, and `classifier.classes_` is
`["cancer", "not cancer"]`.

In addition, the dataset is imbalanced to better identify problems when
using non-symmetric performance metrics such as f1-score, average precision
and so on.

Returns
-------
classifier : estimator object
Trained classifier on the binary problem.
X_test : ndarray of shape (n_samples, n_features)
Data to be used as testing set in tests.
y_test : ndarray of shape (n_samples,), dtype=object
Binary target where labels are strings.
y_pred : ndarray of shape (n_samples,), dtype=object
Prediction of `classifier` when predicting for `X_test`.
y_pred_proba : ndarray of shape (n_samples, 2), dtype=np.float64
Probabilities of `classifier` when predicting for `X_test`.
y_pred_decision : ndarray of shape (n_samples,), dtype=np.float64
Decision function values of `classifier` when predicting on `X_test`.
"""
from sklearn.datasets import load_breast_cancer
from sklearn.utils import shuffle

X, y = load_breast_cancer(return_X_y=True)
# create a highly imbalanced classification task
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=0,
)
classifier = LogisticRegression().fit(X_train, y_train)
y_pred = classifier.predict(X_test)
y_pred_proba = classifier.predict_proba(X_test)
y_pred_decision = classifier.decision_function(X_test)

return classifier, X_test, y_test, y_pred, y_pred_proba, y_pred_decision


def test_average_precision_pos_label(string_labeled_classification_problem):
# check that _ThresholdScorer will lead to the right score when passing
# `pos_label`. Currently, only `average_precision_score` is defined to
# be such a scorer.
clf, X_test, y_test, _, y_pred_proba, y_pred_decision = \
string_labeled_classification_problem

pos_label = "cancer"
# we need to select the positive column or reverse the decision values
y_pred_proba = y_pred_proba[:, 0]
y_pred_decision = y_pred_decision * -1
assert clf.classes_[0] == pos_label

# check that when calling the scoring function, probability estimates and
# decision values lead to the same results
ap_proba = average_precision_score(
y_test, y_pred_proba, pos_label=pos_label
)
ap_decision_function = average_precision_score(
y_test, y_pred_decision, pos_label=pos_label
)
assert ap_proba == pytest.approx(ap_decision_function)

# create a scorer which would require to pass a `pos_label`
# check that it fails if `pos_label` is not provided
average_precision_scorer = make_scorer(
average_precision_score, needs_threshold=True,
)
err_msg = "pos_label=1 is not a valid label. It should be one of "
with pytest.raises(ValueError, match=err_msg):
average_precision_scorer(clf, X_test, y_test)

# otherwise, the scorer should give the same results as calling the
# scoring function
average_precision_scorer = make_scorer(
average_precision_score, needs_threshold=True, pos_label=pos_label
)
ap_scorer = average_precision_scorer(clf, X_test, y_test)

assert ap_scorer == pytest.approx(ap_proba)

# The above scorer call is using `clf.decision_function`. We will force
# it to use `clf.predict_proba`.
clf_without_predict_proba = deepcopy(clf)

def _predict_proba(self, X):
raise NotImplementedError

clf_without_predict_proba.predict_proba = partial(
_predict_proba, clf_without_predict_proba
)
# sanity check
with pytest.raises(NotImplementedError):
clf_without_predict_proba.predict_proba(X_test)

ap_scorer = average_precision_scorer(
clf_without_predict_proba, X_test, y_test
)
assert ap_scorer == pytest.approx(ap_proba)


def test_brier_score_loss_pos_label(string_labeled_classification_problem):
# check that _ProbaScorer leads to the right score when `pos_label` is
# provided. Currently only the `brier_score_loss` is defined to be such
# a scorer.
clf, X_test, y_test, _, y_pred_proba, _ = \
string_labeled_classification_problem

pos_label = "cancer"
assert clf.classes_[0] == pos_label

# brier score loss is symmetric
brier_pos_cancer = brier_score_loss(
y_test, y_pred_proba[:, 0], pos_label="cancer"
)
brier_pos_not_cancer = brier_score_loss(
y_test, y_pred_proba[:, 1], pos_label="not cancer"
)
assert brier_pos_cancer == pytest.approx(brier_pos_not_cancer)

brier_scorer = make_scorer(
brier_score_loss, needs_proba=True, pos_label=pos_label,
)
assert brier_scorer(clf, X_test, y_test) == pytest.approx(brier_pos_cancer)


@pytest.mark.parametrize(
"score_func", [f1_score, precision_score, recall_score, jaccard_score]
)
def test_non_symmetric_metric_pos_label(
score_func, string_labeled_classification_problem
):
# check that _PredictScorer leads to the right score when `pos_label` is
# provided. We check all of the supported metrics.
# Note: At some point we may end up having "scorer tags".
clf, X_test, y_test, y_pred, _, _ = string_labeled_classification_problem

pos_label = "cancer"
assert clf.classes_[0] == pos_label

score_pos_cancer = score_func(y_test, y_pred, pos_label="cancer")
score_pos_not_cancer = score_func(y_test, y_pred, pos_label="not cancer")

assert score_pos_cancer != pytest.approx(score_pos_not_cancer)

scorer = make_scorer(score_func, pos_label=pos_label)
assert scorer(clf, X_test, y_test) == pytest.approx(score_pos_cancer)


@pytest.mark.parametrize(
"scorer",
[
make_scorer(
average_precision_score, needs_threshold=True, pos_label="xxx"
),
make_scorer(brier_score_loss, needs_proba=True, pos_label="xxx"),
make_scorer(f1_score, pos_label="xxx")
],
ids=["ThresholdScorer", "ProbaScorer", "PredictScorer"],
)
def test_scorer_select_proba_error(scorer):
# check that we raise the proper error when passing an unknown
# pos_label
X, y = make_classification(
n_classes=2, n_informative=3, n_samples=20, random_state=0
)
lr = LogisticRegression().fit(X, y)
assert scorer._kwargs["pos_label"] not in np.unique(y).tolist()

err_msg = "is not a valid label"
with pytest.raises(ValueError, match=err_msg):
scorer(lr, X, y)


def test_scorer_no_op_multiclass_select_proba():
# check that calling a ProbaScorer on a multiclass problem does not raise
# even if `y_true` is binary during the scoring.
# `_select_proba_binary` should not be called in this case.
X, y = make_classification(
n_classes=3, n_informative=3, n_samples=20, random_state=0
)
lr = LogisticRegression().fit(X, y)

mask_last_class = y == lr.classes_[-1]
X_test, y_test = X[~mask_last_class], y[~mask_last_class]
assert_array_equal(np.unique(y_test), lr.classes_[:-1])

scorer = make_scorer(
roc_auc_score, needs_proba=True, multi_class="ovo", labels=lr.classes_,
)
scorer(lr, X_test, y_test)
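The decision-value branch exercised by `test_average_precision_pos_label` can be summarized with the small, self-contained sketch below (illustrative values only): `decision_function` scores are oriented towards `classes_[1]`, so they are negated when the requested `pos_label` is `classes_[0]`.

import numpy as np

classes = np.array(["cancer", "not cancer"])   # would be clf.classes_
pos_label = "cancer"
y_decision = np.array([-1.5, 0.3, 2.0])        # higher -> more "not cancer"

if pos_label == classes[0]:
    # re-orient the scores so that higher means more of the positive class
    y_decision = y_decision * -1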