MAINT refactor scorer using _get_response_values #26037


Merged: 11 commits, May 4, 2023
sklearn/metrics/_scorer.py (91 changes: 26 additions & 65 deletions)
@@ -18,8 +18,9 @@
# Arnaud Joly <arnaud.v.joly@gmail.com>
# License: Simplified BSD

from functools import partial
from collections import Counter
from inspect import signature
from functools import partial
from traceback import format_exc

import numpy as np
@@ -64,20 +65,23 @@

from ..utils.multiclass import type_of_target
from ..base import is_regressor
from ..utils._response import _get_response_values
from ..utils._param_validation import HasMethods, StrOptions, validate_params


def _cached_call(cache, estimator, method, *args, **kwargs):
def _cached_call(cache, estimator, response_method, *args, **kwargs):
"""Call estimator with method and args and kwargs."""
if cache is None:
return getattr(estimator, method)(*args, **kwargs)
Member:

Why is this not simply replacing `getattr` with `_get_response_values`?

Member Author:

It is a change requested by @jeremiedbb: #26037 (comment)

It makes the code more readable, as @jeremiedbb argued, at the cost of a larger diff.

if cache is not None and response_method in cache:
return cache[response_method]

result, _ = _get_response_values(
estimator, *args, response_method=response_method, **kwargs
)

if cache is not None:
cache[response_method] = result

try:
return cache[method]
except KeyError:
result = getattr(estimator, method)(*args, **kwargs)
cache[method] = result
return result
return result
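
A minimal sketch of the caching contract this gives us (`_cached_call` and `_get_response_values` are private helpers, so this assumes the signatures stay as in this diff): two calls with the same `response_method` share a single underlying inference.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics._scorer import _cached_call

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

cache = {}
first = _cached_call(cache, clf, "predict_proba", X)
second = _cached_call(cache, clf, "predict_proba", X)  # served from the cache
assert first is second  # the same ndarray is reused, no second inference
```

Passing `cache=None` falls through to a fresh `_get_response_values` call each time, which is the single-metric path.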


class _MultimetricScorer:
@@ -162,40 +166,13 @@ def __init__(self, score_func, sign, kwargs):
self._score_func = score_func
self._sign = sign

@staticmethod
def _check_pos_label(pos_label, classes):
if pos_label not in list(classes):
raise ValueError(f"pos_label={pos_label} is not a valid label: {classes}")

def _select_proba_binary(self, y_pred, classes):
"""Select the column of the positive label in `y_pred` when
probabilities are provided.

Parameters
----------
y_pred : ndarray of shape (n_samples, n_classes)
The prediction given by `predict_proba`.

classes : ndarray of shape (n_classes,)
The class labels for the estimator.

Returns
-------
y_pred : ndarray of shape (n_samples,)
Probability predictions of the positive class.
"""
if y_pred.shape[1] == 2:
pos_label = self._kwargs.get("pos_label", classes[1])
self._check_pos_label(pos_label, classes)
col_idx = np.flatnonzero(classes == pos_label)[0]
return y_pred[:, col_idx]

err_msg = (
f"Got predict_proba of shape {y_pred.shape}, but need "
f"classifier with two classes for {self._score_func.__name__} "
"scoring"
)
raise ValueError(err_msg)
def _get_pos_label(self):
if "pos_label" in self._kwargs:
return self._kwargs["pos_label"]
score_func_params = signature(self._score_func).parameters
if "pos_label" in score_func_params:
return score_func_params["pos_label"].default
return None
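
The resolution order here is: an explicit `pos_label` passed to the scorer wins, then the score function's own default, then `None`. A sketch of the same logic outside the class (the `resolve_pos_label` helper is hypothetical, not part of the PR):

```python
from inspect import signature

from sklearn.metrics import f1_score

def resolve_pos_label(score_func, kwargs):
    # mirrors _get_pos_label above
    if "pos_label" in kwargs:
        return kwargs["pos_label"]
    params = signature(score_func).parameters
    if "pos_label" in params:
        return params["pos_label"].default
    return None

assert resolve_pos_label(f1_score, {"pos_label": 0}) == 0  # explicit kwarg wins
assert resolve_pos_label(f1_score, {}) == 1                # f1_score's own default
```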

def __repr__(self):
kwargs_string = "".join(
@@ -311,14 +288,7 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
score : float
Score function applied to prediction of estimator on X.
"""

y_type = type_of_target(y)
y_pred = method_caller(clf, "predict_proba", X)
if y_type == "binary" and y_pred.shape[1] <= 2:
# `y_type` could be equal to "binary" even in a multi-class
# problem: (when only 2 class are given to `y_true` during scoring)
# Thus, we need to check for the shape of `y_pred`.
y_pred = self._select_proba_binary(y_pred, clf.classes_)
y_pred = method_caller(clf, "predict_proba", X, pos_label=self._get_pos_label())
if sample_weight is not None:
return self._sign * self._score_func(
y, y_pred, sample_weight=sample_weight, **self._kwargs
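
For context, `_get_response_values` now performs the binary column selection that the removed `_select_proba_binary` used to do. A hedged sketch, assuming a fitted binary classifier:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils._response import _get_response_values

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

y_score, pos_label = _get_response_values(
    clf, X, response_method="predict_proba", pos_label=0
)
assert y_score.shape == (X.shape[0],)  # one column: clf.predict_proba(X)[:, 0]
assert pos_label == 0                  # the requested positive class
```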
@@ -369,26 +339,17 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
if is_regressor(clf):
y_pred = method_caller(clf, "predict", X)
else:
pos_label = self._get_pos_label()
try:
y_pred = method_caller(clf, "decision_function", X)
y_pred = method_caller(clf, "decision_function", X, pos_label=pos_label)

if isinstance(y_pred, list):
# For multi-output multi-class estimator
y_pred = np.vstack([p for p in y_pred]).T
elif y_type == "binary" and "pos_label" in self._kwargs:
self._check_pos_label(self._kwargs["pos_label"], clf.classes_)
if self._kwargs["pos_label"] == clf.classes_[0]:
# The implicit positive class of the binary classifier
# does not match `pos_label`: we need to invert the
# predictions
y_pred *= -1

except (NotImplementedError, AttributeError):
y_pred = method_caller(clf, "predict_proba", X)

if y_type == "binary":
y_pred = self._select_proba_binary(y_pred, clf.classes_)
elif isinstance(y_pred, list):
y_pred = method_caller(clf, "predict_proba", X, pos_label=pos_label)
if isinstance(y_pred, list):
y_pred = np.vstack([p[:, -1] for p in y_pred]).T

if sample_weight is not None:
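Likewise, the removed `y_pred *= -1` branch now lives inside `_get_response_values`: when `pos_label` is the estimator's negative class, the decision values are negated so that larger scores always favor the requested class (the test near the end of this diff asserts exactly this). A sketch, assuming a fitted binary classifier:

```python
import numpy as np

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils._response import _get_response_values

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

y_score, pos_label = _get_response_values(
    clf, X, response_method="decision_function", pos_label=clf.classes_[0]
)
np.testing.assert_allclose(y_score, -clf.decision_function(X))  # sign flipped
assert pos_label == clf.classes_[0]
```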
sklearn/metrics/tests/test_score_objects.py (15 changes: 10 additions & 5 deletions)
@@ -759,13 +759,18 @@ def test_multimetric_scorer_calls_method_once(
X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])

mock_est = Mock()
fit_func = Mock(return_value=mock_est)
predict_func = Mock(return_value=y)
mock_est._estimator_type = "classifier"
fit_func = Mock(return_value=mock_est, name="fit")
fit_func.__name__ = "fit"
predict_func = Mock(return_value=y, name="predict")
predict_func.__name__ = "predict"

pos_proba = np.random.rand(X.shape[0])
proba = np.c_[1 - pos_proba, pos_proba]
predict_proba_func = Mock(return_value=proba)
decision_function_func = Mock(return_value=pos_proba)
predict_proba_func = Mock(return_value=proba, name="predict_proba")
predict_proba_func.__name__ = "predict_proba"
decision_function_func = Mock(return_value=pos_proba, name="decision_function")
decision_function_func.__name__ = "decision_function"

mock_est.fit = fit_func
mock_est.predict = predict_func
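
A plain `Mock` raises `AttributeError` on dunder lookups, and `_get_response_values` inspects `prediction_method.__name__`, which is presumably why the mocks above now get an explicit `__name__`. A quick illustration:

```python
from unittest.mock import Mock

m = Mock(name="predict_proba")  # `name` only configures the repr
try:
    m.__name__  # dunders are not auto-created on Mock
except AttributeError:
    m.__name__ = "predict_proba"  # what the test now sets explicitly
```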
@@ -961,7 +966,7 @@ def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
n_classes=3, n_informative=3, n_samples=20, random_state=0
)
lr = Perceptron().fit(X, y)
msg = "'Perceptron' object has no attribute 'predict_proba'"
msg = "Perceptron has none of the following attributes: predict_proba."
with pytest.raises(AttributeError, match=msg):
scorer(lr, X, y)

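The new message comes from `_check_response_method`, which lists every attribute it tried. It is reproducible outside the test suite (a sketch, assuming the internal APIs stay as in this diff):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import Perceptron
from sklearn.utils._response import _get_response_values

X, y = make_classification(random_state=0)
lr = Perceptron().fit(X, y)
try:
    _get_response_values(lr, X, response_method="predict_proba")
except AttributeError as exc:
    print(exc)  # Perceptron has none of the following attributes: predict_proba.
```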
sklearn/utils/_response.py (10 changes: 0 additions & 10 deletions)
@@ -19,9 +19,6 @@ def _get_response_values(
The response values are predictions, one scalar value for each sample in X
that depends on the specific choice of `response_method`.

This helper only accepts multiclass classifiers with the `predict` response
method.

If `estimator` is a binary classifier, also return the label for the
effective positive class.

@@ -75,15 +72,8 @@
if is_classifier(estimator):
prediction_method = _check_response_method(estimator, response_method)
classes = estimator.classes_

target_type = "binary" if len(classes) <= 2 else "multiclass"

if target_type == "multiclass" and prediction_method.__name__ != "predict":
raise ValueError(
"With a multiclass estimator, the response method should be "
f"predict, got {prediction_method.__name__} instead."
)

if pos_label is not None and pos_label not in classes.tolist():
raise ValueError(
f"pos_label={pos_label} is not a valid label: It should be "
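With the multiclass guard gone, `predict_proba` and `decision_function` on a multiclass estimator now pass through untouched, mirroring the new test at the bottom of this diff:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils._response import _get_response_values

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

proba, pos_label = _get_response_values(clf, X, response_method="predict_proba")
assert proba.shape == (X.shape[0], len(clf.classes_))  # (150, 3), untouched
assert pos_label is None  # pos_label is only meaningful for binary targets
```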
sklearn/utils/tests/test_response.py (45 changes: 25 additions & 20 deletions)
@@ -6,7 +6,7 @@
LinearRegression,
LogisticRegression,
)
from sklearn.svm import SVC
from sklearn.preprocessing import scale
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._mocking import _MockEstimatorOnOffPrediction
from sklearn.utils._testing import assert_allclose, assert_array_equal
@@ -15,6 +15,8 @@


X, y = load_iris(return_X_y=True)
# scale the data to avoid ConvergenceWarning with LogisticRegression
X = scale(X, copy=False)
X_binary, y_binary = X[:100], y[:100]


@@ -29,25 +31,6 @@ def test_get_response_values_regressor_error(response_method):
_get_response_values(my_estimator, X, response_method=response_method)


@pytest.mark.parametrize(
"estimator, response_method",
[
(DecisionTreeClassifier(), "predict_proba"),
(SVC(), "decision_function"),
],
)
def test_get_response_values_error_multiclass_classifier(estimator, response_method):
"""Check that we raise an error with multiclass classifier and requesting
response values different from `predict`."""
X, y = make_classification(
n_samples=10, n_clusters_per_class=1, n_classes=3, random_state=0
)
classifier = estimator.fit(X, y)
err_msg = "With a multiclass estimator, the response method should be predict"
with pytest.raises(ValueError, match=err_msg):
_get_response_values(classifier, X, response_method=response_method)


def test_get_response_values_regressor():
"""Check the behaviour of `_get_response_values` with regressor."""
X, y = make_regression(n_samples=10, random_state=0)
@@ -227,3 +210,25 @@ def test_get_response_decision_function():
)
np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1)
assert pos_label == 0


@pytest.mark.parametrize(
"estimator, response_method",
[
(DecisionTreeClassifier(max_depth=2, random_state=0), "predict_proba"),
(LogisticRegression(), "decision_function"),
],
)
def test_get_response_values_multiclass(estimator, response_method):
"""Check that we can call `_get_response_values` with a multiclass estimator.
It should return the predictions untouched.
"""
estimator.fit(X, y)
predictions, pos_label = _get_response_values(
estimator, X, response_method=response_method
)

assert pos_label is None
assert predictions.shape == (X.shape[0], len(estimator.classes_))
if response_method == "predict_proba":
assert np.logical_and(predictions >= 0, predictions <= 1).all()
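
End to end, the user-visible behavior should be unchanged by this MAINT refactor; a threshold scorer such as `"roc_auc"` simply routes its `decision_function`/`predict_proba` lookup through `_get_response_values` now. A sketch using only public API:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)
print(get_scorer("roc_auc")(clf, X, y))  # same score as before the refactor
```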