Skip to content

MNT refactor _get_response_values #21538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
32 changes: 22 additions & 10 deletions sklearn/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
RegressorMixin,
clone,
MetaEstimatorMixin,
is_classifier,
)
from .preprocessing import label_binarize, LabelEncoder
from .utils import (
Expand All @@ -34,7 +33,10 @@
check_matplotlib_support,
)

from .utils.multiclass import check_classification_targets
from .utils.multiclass import (
check_classification_targets,
type_of_target,
)
from .utils.fixes import delayed
from .utils.validation import (
_check_fit_params,
Expand All @@ -43,12 +45,12 @@
check_consistent_length,
check_is_fitted,
)
from .utils import _safe_indexing
from .utils import _get_response_values, _safe_indexing
from .isotonic import IsotonicRegression
from .svm import LinearSVC
from .model_selection import check_cv, cross_val_predict
from .metrics._base import _check_pos_label_consistency
from .metrics._plot.base import _get_response
from .metrics._plot.base import _check_estimator_and_target_is_binary


class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down Expand Up @@ -1235,11 +1237,15 @@ def from_estimator(
method_name = f"{cls.__name__}.from_estimator"
check_matplotlib_support(method_name)

if not is_classifier(estimator):
raise ValueError("'estimator' should be a fitted classifier.")

y_prob, pos_label = _get_response(
X, estimator, response_method="predict_proba", pos_label=pos_label
target_type = type_of_target(y)
_check_estimator_and_target_is_binary(estimator, y, target_type=target_type)
y_prob, pos_label = _get_response_values(
estimator,
X,
y,
response_method="predict_proba",
pos_label=pos_label,
target_type=target_type,
)

name = name if name is not None else estimator.__class__.__name__
Expand Down Expand Up @@ -1352,9 +1358,15 @@ def from_predictions(
>>> disp = CalibrationDisplay.from_predictions(y_test, y_prob)
>>> plt.show()
"""
method_name = f"{cls.__name__}.from_estimator"
method_name = f"{cls.__name__}.from_predictions"
check_matplotlib_support(method_name)

target_type = type_of_target(y_true)
if target_type != "binary":
raise ValueError(
f"The target y is not binary. Got {target_type} type of target."
)

prob_true, prob_pred = calibration_curve(
y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label
)
Expand Down
34 changes: 17 additions & 17 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@
from ..utils import Bunch
from ..utils.metaestimators import available_if
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils.validation import column_or_1d
from ..utils.validation import (
_check_response_method,
check_is_fitted,
column_or_1d,
)
from ..utils.fixes import delayed
from ..utils.validation import _check_feature_names_in

Expand Down Expand Up @@ -120,21 +123,18 @@ def _concatenate_predictions(self, X, predictions):
def _method_name(name, estimator, method):
if estimator == "drop":
return None
if method == "auto":
if getattr(estimator, "predict_proba", None):
return "predict_proba"
elif getattr(estimator, "decision_function", None):
return "decision_function"
else:
return "predict"
else:
if not hasattr(estimator, method):
raise ValueError(
"Underlying estimator {} does not implement the method {}.".format(
name, method
)
)
return method
method = (
["predict_proba", "decision_function", "predict"]
if method == "auto"
else method
)
Comment on lines +126 to +130
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit lighter to parse IMO (and also what is used in other places in this PR):

Suggested change
method = (
["predict_proba", "decision_function", "predict"]
if method == "auto"
else method
)
if method == "auto":
method = ["predict_proba", "decision_function", "predict"]

try:
method_name = _check_response_method(estimator, method).__name__
except AttributeError as e:
raise ValueError(
f"Underlying estimator {name} does not implement the method {method}."
) from e
return method_name

def fit(self, X, y, sample_weight=None):
"""Fit the estimators.
Expand Down
150 changes: 45 additions & 105 deletions sklearn/metrics/_plot/base.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,56 @@
from ...base import is_classifier
from ...exceptions import NotFittedError
from ...utils.multiclass import type_of_target
from ...utils.validation import check_is_fitted


def _check_classifier_response_method(estimator, response_method):
"""Return prediction method from the response_method
def _check_estimator_and_target_is_binary(estimator, y, target_type=None):
"""Helper to check that estimator is a binary classifier and y is binary.

Parameters
----------
estimator: object
Classifier to check

response_method: {'auto', 'predict_proba', 'decision_function'}
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.

Returns
-------
prediction_method: callable
prediction method of estimator
"""

if response_method not in ("predict_proba", "decision_function", "auto"):
raise ValueError(
"response_method must be 'predict_proba', 'decision_function' or 'auto'"
)

error_msg = "response method {} is not defined in {}"
if response_method != "auto":
prediction_method = getattr(estimator, response_method, None)
if prediction_method is None:
raise ValueError(
error_msg.format(response_method, estimator.__class__.__name__)
)
else:
predict_proba = getattr(estimator, "predict_proba", None)
decision_function = getattr(estimator, "decision_function", None)
prediction_method = predict_proba or decision_function
if prediction_method is None:
raise ValueError(
error_msg.format(
"decision_function or predict_proba", estimator.__class__.__name__
)
)

return prediction_method


def _get_response(X, estimator, response_method, pos_label=None):
"""Return response and positive label.

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input values.

estimator : estimator instance
Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
in which the last estimator is a classifier.

response_method: {'auto', 'predict_proba', 'decision_function'}
Specifies whether to use :term:`predict_proba` or
:term:`decision_function` as the target response. If set to 'auto',
:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.

pos_label : str or int, default=None
The class considered as the positive class when computing
the metrics. By default, `estimators.classes_[1]` is
considered as the positive class.

Returns
-------
y_pred: ndarray of shape (n_samples,)
Target scores calculated from the provided response_method
and pos_label.

pos_label: str or int
The class considered as the positive class when computing
the metrics.
An estimator that should be used to predict the target.

y : ndarray
The associated target.

target_type : str, default=None
The type of the target `y` as returned by
:func:`~sklearn.utils.multiclass.type_of_target`. If `None`, the type
will be inferred by calling :func:`~sklearn.utils.multiclass.type_of_target`.
Providing the type of the target could save time by avoiding a call to the
:func:`~sklearn.utils.multiclass.type_of_target` function.

Raises
------
ValueError
If the estimator or the target are not binary.
"""
classification_error = (
"Expected 'estimator' to be a binary classifier, but got"
f" {estimator.__class__.__name__}"
)
try:
check_is_fitted(estimator)
except NotFittedError as e:
raise NotFittedError(
f"This {estimator.__class__.__name__} instance is not fitted yet. Call "
"'fit' with appropriate arguments before intending to use it to plotting "
"functionalities."
Comment on lines +35 to +36
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a bit lighter to read (better suggestions welcome):

Suggested change
"'fit' with appropriate arguments before intending to use it to plotting "
"functionalities."
"'fit' with appropriate arguments before using it for plotting "
"functionalities."

) from e

if not is_classifier(estimator):
raise ValueError(classification_error)

prediction_method = _check_classifier_response_method(estimator, response_method)
y_pred = prediction_method(X)
if pos_label is not None:
try:
class_idx = estimator.classes_.tolist().index(pos_label)
except ValueError as e:
raise ValueError(
"The class provided by 'pos_label' is unknown. Got "
f"{pos_label} instead of one of {set(estimator.classes_)}"
) from e
else:
class_idx = 1
pos_label = estimator.classes_[class_idx]

if y_pred.ndim != 1: # `predict_proba`
y_pred_shape = y_pred.shape[1]
if y_pred_shape != 2:
raise ValueError(
f"{classification_error} fit on multiclass ({y_pred_shape} classes)"
" data"
)
y_pred = y_pred[:, class_idx]
elif pos_label == estimator.classes_[0]: # `decision_function`
y_pred *= -1
raise ValueError(
"This plotting functionalities only support a binary classifier. "
f"Got a {estimator.__class__.__name__} instead."
)
elif len(estimator.classes_) != 2:
raise ValueError(
f"This {estimator.__class__.__name__} instance is not a binary "
"classifier. It was fitted on multiclass problem with "
f"{len(estimator.classes_)} classes."
)

return y_pred, pos_label
if target_type is None:
target_type = type_of_target(y)
if target_type != "binary":
raise ValueError(
f"The target y is not binary. Got {target_type} type of target."
)
37 changes: 30 additions & 7 deletions sklearn/metrics/_plot/det_curve.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import scipy as sp

from .base import _get_response
from .base import _check_estimator_and_target_is_binary

from .. import det_curve
from .._base import _check_pos_label_consistency

from ...utils import check_matplotlib_support
from ...utils import deprecated
from ...utils import (
check_matplotlib_support,
deprecated,
_get_response_values,
)
from ...utils.multiclass import type_of_target


class DetCurveDisplay:
Expand Down Expand Up @@ -168,13 +172,20 @@ def from_estimator(
"""
check_matplotlib_support(f"{cls.__name__}.from_estimator")

target_type = type_of_target(y)
_check_estimator_and_target_is_binary(estimator, y, target_type=target_type)
if response_method == "auto":
response_method = ["predict_proba", "decision_function"]

name = estimator.__class__.__name__ if name is None else name

y_pred, pos_label = _get_response(
X,
y_pred, pos_label = _get_response_values(
estimator,
X,
y,
response_method,
pos_label=pos_label,
target_type=target_type,
)

return cls.from_predictions(
Expand Down Expand Up @@ -265,6 +276,13 @@ def from_predictions(
>>> plt.show()
"""
check_matplotlib_support(f"{cls.__name__}.from_predictions")

target_type = type_of_target(y_true)
if target_type != "binary":
raise ValueError(
f"The target y is not binary. Got {target_type} type of target."
)
Comment on lines +280 to +284
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this check here, since we do the check in det_curve?

if len(np.unique(y_true)) != 2:
raise ValueError(
"Only one class present in y_true. Detection error "
"tradeoff curve is not defined in that case."
)


fpr, fnr, _ = det_curve(
y_true,
y_pred,
Expand Down Expand Up @@ -454,8 +472,13 @@ def plot_det_curve(
"""
check_matplotlib_support("plot_det_curve")

y_pred, pos_label = _get_response(
X, estimator, response_method, pos_label=pos_label
target_type = type_of_target(y)
_check_estimator_and_target_is_binary(estimator, y, target_type=target_type)
if response_method == "auto":
response_method = ["predict_proba", "decision_function"]

y_pred, pos_label = _get_response_values(
estimator, X, y, response_method, pos_label=pos_label, target_type=target_type
)

fpr, fnr, _ = det_curve(
Expand Down
Loading