From f1485933c1662605d45c401a789aacae11f1c8db Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 7 Apr 2022 17:51:58 +0200 Subject: [PATCH 01/15] introduce check_response_methods in stacking --- sklearn/ensemble/_stacking.py | 32 +++++++------- sklearn/utils/_mocking.py | 42 ++++++++++++++++++ sklearn/utils/tests/test_mocking.py | 30 ++++++++++++- sklearn/utils/tests/test_validation.py | 60 +++++++++++++++++++++++++- sklearn/utils/validation.py | 47 +++++++++++++++++++- 5 files changed, 191 insertions(+), 20 deletions(-) diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index a55402e02ef7c..eea0a8701901a 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -30,10 +30,13 @@ from ..utils import Bunch from ..utils.metaestimators import available_if from ..utils.multiclass import check_classification_targets -from ..utils.validation import check_is_fitted -from ..utils.validation import column_or_1d from ..utils.fixes import delayed -from ..utils.validation import _check_feature_names_in +from ..utils.validation import ( + _check_feature_names_in, + _check_response_method, + check_is_fitted, + column_or_1d, +) def _estimator_has(attr): @@ -121,20 +124,15 @@ def _method_name(name, estimator, method): if estimator == "drop": return None if method == "auto": - if getattr(estimator, "predict_proba", None): - return "predict_proba" - elif getattr(estimator, "decision_function", None): - return "decision_function" - else: - return "predict" - else: - if not hasattr(estimator, method): - raise ValueError( - "Underlying estimator {} does not implement the method {}.".format( - name, method - ) - ) - return method + method = ["predict_proba", "decision_function", "predict"] + try: + method_name = _check_response_method(estimator, method).__name__ + except AttributeError as e: + raise ValueError( + f"Underlying estimator {name} does not implement the method {method}." + ) from e + + return method_name def fit(self, X, y, sample_weight=None): """Fit the estimators. diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index c7451dce1fbc5..688bfb68ed484 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -1,6 +1,7 @@ import numpy as np from ..base import BaseEstimator, ClassifierMixin +from .metaestimators import available_if from .validation import _check_sample_weight, _num_samples, check_array from .validation import check_is_fitted @@ -344,3 +345,44 @@ def predict_proba(self, X): def _more_tags(self): return {"_skip_test": True} + + +def _check_response(method): + def check(self): + return self.response_methods is not None and method in self.response_methods + + return check + + +class _MockEstimatorOnOffPrediction(BaseEstimator): + """Estimator for which we can turn on/off the prediction methods. + + Parameters + ---------- + response_methods: list of \ + {"predict", "predict_proba", "decision_function"}, default=None + List containing the response implemented by the estimator. When, the + response is in the list, it will return the name of the response method + when called. Otherwise, an `AttributeError` is raised. It allows to + use `getattr` as any conventional estimator. By default, no response + methods are mocked. 
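
    Examples
    --------
    An illustrative sketch of the on/off behaviour; the toy data below is an
    arbitrary stand-in.

    >>> import numpy as np
    >>> est = _MockEstimatorOnOffPrediction(response_methods=["predict_proba"])
    >>> est.fit(np.array([[0.0], [1.0]]), np.array([0, 1])).predict_proba(None)
    'predict_proba'
    >>> hasattr(est, "decision_function")  # not listed, so hidden
    False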
+ """ + + def __init__(self, response_methods=None): + self.response_methods = response_methods + + def fit(self, X, y): + self.classes_ = np.unique(y) + return self + + @available_if(_check_response("predict")) + def predict(self, X): + return "predict" + + @available_if(_check_response("predict_proba")) + def predict_proba(self, X): + return "predict_proba" + + @available_if(_check_response("decision_function")) + def decision_function(self, X): + return "decision_function" diff --git a/sklearn/utils/tests/test_mocking.py b/sklearn/utils/tests/test_mocking.py index a12c41256581a..8cdac2e6a53f3 100644 --- a/sklearn/utils/tests/test_mocking.py +++ b/sklearn/utils/tests/test_mocking.py @@ -10,7 +10,10 @@ from sklearn.utils import _safe_indexing from sklearn.utils._testing import _convert_container -from sklearn.utils._mocking import CheckingClassifier +from sklearn.utils._mocking import ( + _MockEstimatorOnOffPrediction, + CheckingClassifier, +) @pytest.fixture @@ -181,3 +184,28 @@ def test_checking_classifier_methods_to_check(iris, methods_to_check, predict_me getattr(clf, predict_method)(X) else: getattr(clf, predict_method)(X) + + +@pytest.mark.parametrize( + "response_methods", + [ + ["predict"], + ["predict", "predict_proba"], + ["predict", "decision_function"], + ["predict", "predict_proba", "decision_function"], + ], +) +def test_mock_estimator_on_off_prediction(iris, response_methods): + X, y = iris + estimator = _MockEstimatorOnOffPrediction(response_methods=response_methods) + + estimator.fit(X, y) + assert hasattr(estimator, "classes_") + assert_array_equal(estimator.classes_, np.unique(y)) + + possible_responses = ["predict", "predict_proba", "decision_function"] + for response in possible_responses: + if response in response_methods: + assert hasattr(estimator, response) + else: + assert not hasattr(estimator, response) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 84a71a10981fb..0fbe0553b46a2 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -23,7 +23,13 @@ from sklearn.utils import as_float_array, check_array, check_symmetric from sklearn.utils import check_X_y from sklearn.utils import deprecated -from sklearn.utils._mocking import MockDataFrame + +# TODO: add this estimator into the _mocking module in a further refactoring +from sklearn.metrics.tests.test_score_objects import EstimatorWithFit +from sklearn.utils._mocking import ( + MockDataFrame, + _MockEstimatorOnOffPrediction, +) from sklearn.utils.fixes import parse_version from sklearn.utils.estimator_checks import _NotAnArray from sklearn.random_projection import _sparse_random_matrix @@ -52,6 +58,7 @@ _get_feature_names, _check_feature_names_in, _check_fit_params, + _check_response_method, ) from sklearn.base import BaseEstimator import sklearn @@ -1704,3 +1711,54 @@ def test_check_feature_names_in_pandas(): with pytest.raises(ValueError, match="input_features is not equal to"): est.get_feature_names_out(["x1", "x2", "x3"]) + + +def test_check_response_method_unknown_method(): + """Check the error message when passing an unknown response method.""" + err_msg = ( + "RandomForestRegressor has none of the following attributes: unknown_method." 
+ ) + with pytest.raises(AttributeError, match=err_msg): + _check_response_method(RandomForestRegressor(), "unknown_method") + + +@pytest.mark.parametrize( + "response_method", ["decision_function", "predict_proba", "predict"] +) +def test_check_response_method_not_supported_response_method(response_method): + """Check the error message when a response method is not supported by the + estimator.""" + err_msg = ( + f"EstimatorWithFit has none of the following attributes: {response_method}." + ) + with pytest.raises(AttributeError, match=err_msg): + _check_response_method(EstimatorWithFit(), response_method) + + +def test_check_response_method_list_str(): + """Check that we can pass a list of ordered method.""" + method_implemented = ["predict_proba"] + my_estimator = _MockEstimatorOnOffPrediction(method_implemented) + + X = "mocking_data" + + # raise an error when no methods are defined + response_method = ["decision_function", "predict"] + err_msg = ( + "_MockEstimatorOnOffPrediction has none of the following attributes: " + f"{', '.join(response_method)}." + ) + with pytest.raises(AttributeError, match=err_msg): + _check_response_method(my_estimator, response_method)(X) + + # check that we don't get issue when one of the method is defined + response_method = ["decision_function", "predict_proba"] + method_name_predicting = _check_response_method(my_estimator, response_method)(X) + assert method_name_predicting == "predict_proba" + + # check the order of the methods returned + method_implemented = ["predict_proba", "predict"] + my_estimator = _MockEstimatorOnOffPrediction(method_implemented) + response_method = ["decision_function", "predict", "predict_proba"] + method_name_predicting = _check_response_method(my_estimator, response_method)(X) + assert method_name_predicting == "predict" diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index aba4e2b179953..8c72dc30fdbef 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -9,7 +9,7 @@ # Sylvain Marie # License: BSD 3 clause -from functools import wraps +from functools import reduce, wraps import warnings import numbers import operator @@ -1769,6 +1769,51 @@ def _allclose_dense_sparse(x, y, rtol=1e-7, atol=1e-9): ) +def _check_response_method(estimator, response_method): + """Check if `response_method` is available in estimator and return it. + + .. versionadded:: 1.1 + + Parameters + ---------- + estimator : estimator instance + Classifier or regressor to check. + response_method : {"predict_proba", "decision_function", "predict"} or \ + list of such str + Specifies the response method to use get prediction from an estimator + (i.e. :term:`predict_proba`, :term:`decision_function` or + :term:`predict`). Possible choices are: + - if `str`, it corresponds to the name to the method to return; + - if a list of `str`, it provides the method names in order of + preference. The method returned corresponds to the first method in + the list and which is implemented by `estimator`. + + Returns + ------- + prediction_method : callable + Prediction method of estimator. + + Raises + ------ + AttributeError + If `response_method` is not available in `estimator`. 
+ """ + if isinstance(response_method, str): + list_methods = [response_method] + else: + list_methods = response_method + + prediction_method = [getattr(estimator, method, None) for method in list_methods] + prediction_method = reduce(lambda x, y: x or y, prediction_method) + if prediction_method is None: + raise AttributeError( + f"{estimator.__class__.__name__} has none of the following attributes: " + f"{', '.join(list_methods)}." + ) + + return prediction_method + + def _check_fit_params(X, fit_params, indices=None): """Check and validate the parameters passed during `fit`. From 2fb9a8eb770d2b4345a7a7dd7bbcc08121a496cb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 7 Apr 2022 18:16:39 +0200 Subject: [PATCH 02/15] add _get_response_values --- sklearn/calibration.py | 27 +++-- sklearn/tests/test_calibration.py | 10 +- sklearn/utils/__init__.py | 116 ++++++++++++++++++++++ sklearn/utils/tests/test_utils.py | 160 +++++++++++++++++++++++++++++- 4 files changed, 301 insertions(+), 12 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 684ec91ebb86b..7c7006392063f 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -32,9 +32,14 @@ column_or_1d, indexable, check_matplotlib_support, + _get_response_values, + _safe_indexing, ) -from .utils.multiclass import check_classification_targets +from .utils.multiclass import ( + check_classification_targets, + type_of_target, +) from .utils.fixes import delayed from .utils.validation import ( _check_fit_params, @@ -43,12 +48,10 @@ check_consistent_length, check_is_fitted, ) -from .utils import _safe_indexing from .isotonic import IsotonicRegression from .svm import LinearSVC from .model_selection import check_cv, cross_val_predict from .metrics._base import _check_pos_label_consistency -from .metrics._plot.base import _get_response class CalibratedClassifierCV(ClassifierMixin, MetaEstimatorMixin, BaseEstimator): @@ -1235,11 +1238,17 @@ def from_estimator( method_name = f"{cls.__name__}.from_estimator" check_matplotlib_support(method_name) + target_type = type_of_target(y) if not is_classifier(estimator): raise ValueError("'estimator' should be a fitted classifier.") - y_prob, pos_label = _get_response( - X, estimator, response_method="predict_proba", pos_label=pos_label + y_prob, pos_label = _get_response_values( + estimator, + X, + y, + response_method="predict_proba", + pos_label=pos_label, + target_type=target_type, ) name = name if name is not None else estimator.__class__.__name__ @@ -1352,9 +1361,15 @@ def from_predictions( >>> disp = CalibrationDisplay.from_predictions(y_test, y_prob) >>> plt.show() """ - method_name = f"{cls.__name__}.from_estimator" + method_name = f"{cls.__name__}.from_predictions" check_matplotlib_support(method_name) + target_type = type_of_target(y_true) + if target_type != "binary": + raise ValueError( + f"The target y is not binary. Got {target_type} type of target." 
+ ) + prob_true, prob_pred = calibration_curve( y_true, y_prob, n_bins=n_bins, strategy=strategy, pos_label=pos_label ) diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index cdfdf4f97b78b..86a5f7e27c89d 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -198,7 +198,7 @@ def test_parallel_execution(data, method, ensemble): X, y = data X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) - base_estimator = LinearSVC(random_state=42) + base_estimator = make_pipeline(StandardScaler(), LinearSVC(random_state=42)) cal_clf_parallel = CalibratedClassifierCV( base_estimator, method=method, n_jobs=2, ensemble=ensemble @@ -636,8 +636,8 @@ def test_calibration_display_validation(pyplot, iris_data, iris_data_binary): CalibrationDisplay.from_estimator(reg, X, y) clf = LinearSVC().fit(X, y) - msg = "response method predict_proba is not defined in" - with pytest.raises(ValueError, match=msg): + msg = "has none of the following attributes: predict_proba." + with pytest.raises(AttributeError, match=msg): CalibrationDisplay.from_estimator(clf, X, y) clf = LogisticRegression() @@ -653,11 +653,11 @@ def test_calibration_display_non_binary(pyplot, iris_data, constructor_name): y_prob = clf.predict_proba(X) if constructor_name == "from_estimator": - msg = "to be a binary classifier, but got" + msg = "The target y is not binary. Got multiclass type of target." with pytest.raises(ValueError, match=msg): CalibrationDisplay.from_estimator(clf, X, y) else: - msg = "y should be a 1d array, got an array of shape" + msg = "The target y is not binary. Got multiclass type of target." with pytest.raises(ValueError, match=msg): CalibrationDisplay.from_predictions(y, y_prob) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 83ff96428a257..83664a18eee31 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -28,6 +28,7 @@ from .deprecation import deprecated from .fixes import parse_version, threadpool_info from ._estimator_html_repr import estimator_html_repr +from .multiclass import type_of_target from .validation import ( as_float_array, assert_all_finite, @@ -39,6 +40,7 @@ indexable, check_symmetric, check_scalar, + _check_response_method, ) from .. import get_config from ._bunch import Bunch @@ -1264,3 +1266,117 @@ def is_abstract(c): # itemgetter is used to ensure the sort does not extend to the 2nd item of # the tuple return sorted(set(estimators), key=itemgetter(0)) + + +def _get_response_values( + estimator, + X, + y_true, + response_method, + pos_label=None, + target_type=None, +): + """Compute the response values of a classifier or a regressor. + + The response values are predictions, one scalar value for each sample in X + that depends on the specific choice of `response_method`. + + This helper only accepts multiclass classifiers with the `predict` response + method. + + If `estimator` is a binary classifier, also return the label for the + effective positive class. + + .. versionadded:: 1.1 + + Parameters + ---------- + estimator : estimator instance + Fitted classifier or regressor or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a classifier. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y_true : array-like of shape (n_samples,) + The true label. 
+ + response_method : {"predict_proba", "decision_function", "predict"} or \ + list of such str + Specifies the response method to use get prediction from an estimator + (i.e. :term:`predict_proba`, :term:`decision_function` or + :term:`predict`). Possible choices are: + + - if `str`, it corresponds to the name to the method to return; + - if a list of `str`, it provides the method names in order of + preference. The method returned corresponds to the first method in + the list and which is implemented by `estimator`. + + pos_label : str or int, default=None + The class considered as the positive class when computing + the metrics. By default, `estimators.classes_[1]` is + considered as the positive class. + + target_type : str, default=None + The type of the target `y` as returned by + :func:`~sklearn.utils.multiclass.type_of_target`. If `None`, the type + will be inferred by calling :func:`~sklearn.utils.multiclass.type_of_target`. + Providing the type of the target could save time by avoid calling the + :func:`~sklearn.utils.multiclass.type_of_target` function. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Target scores calculated from the provided response_method + and `pos_label`. + + pos_label : str, int or None + The class considered as the positive class when computing + the metrics. Returns `None` if `estimator` is a regressor. + + Raises + ------ + ValueError + If `pos_label` is not a valid label. + If the shape of `y_pred` is not consistent for binary classifier. + If the response method can be applied to a classifier only and + `estimator` is a regressor. + """ + from sklearn.base import is_classifier # noqa + + if is_classifier(estimator): + if target_type is None: + target_type = type_of_target(y_true) + prediction_method = _check_response_method(estimator, response_method) + y_pred = prediction_method(X) + classes = estimator.classes_ + + if pos_label is not None and pos_label not in classes.tolist(): + raise ValueError( + f"pos_label={pos_label} is not a valid label: It should be " + f"one of {classes}" + ) + elif pos_label is None and target_type == "binary": + pos_label = pos_label if pos_label is not None else classes[-1] + + if prediction_method.__name__ == "predict_proba": + if target_type == "binary" and y_pred.shape[1] <= 2: + if y_pred.shape[1] == 2: + col_idx = np.flatnonzero(classes == pos_label)[0] + y_pred = y_pred[:, col_idx] + else: + err_msg = ( + f"Got predict_proba of shape {y_pred.shape}, but need " + "classifier with two classes." 
+ ) + raise ValueError(err_msg) + elif prediction_method.__name__ == "decision_function": + if target_type == "binary": + if pos_label == classes[0]: + y_pred *= -1 + else: + if response_method != "predict": + raise ValueError(f"{estimator.__class__.__name__} should be a classifier") + y_pred, pos_label = estimator.predict(X), None + + return y_pred, pos_label diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 64a8229d5a549..d7dea283f4b27 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -8,7 +8,18 @@ import numpy as np import scipy.sparse as sp +from sklearn.datasets import ( + make_classification, + make_regression, +) +from sklearn.linear_model import ( + LinearRegression, + LogisticRegression, +) +from sklearn.tree import DecisionTreeClassifier + from sklearn.utils._testing import ( + assert_allclose, assert_array_equal, assert_allclose_dense_sparse, assert_no_warnings, @@ -30,7 +41,11 @@ from sklearn.utils import is_scalar_nan from sklearn.utils import _to_object_array from sklearn.utils import _approximate_mode -from sklearn.utils._mocking import MockDataFrame +from sklearn.utils import _get_response_values +from sklearn.utils._mocking import ( + _MockEstimatorOnOffPrediction, + MockDataFrame, +) from sklearn import config_context # toy array @@ -720,3 +735,146 @@ def test_to_object_array(sequence): assert isinstance(out, np.ndarray) assert out.dtype.kind == "O" assert out.ndim == 1 + + +@pytest.mark.parametrize("response_method", ["decision_function", "predict_proba"]) +def test_get_response_values_regressor_error(response_method): + """Check the error message with regressor an not supported response + method.""" + my_estimator = _MockEstimatorOnOffPrediction(response_methods=[response_method]) + X, y = "mocking_data", "mocking_target" + err_msg = f"{my_estimator.__class__.__name__} should be a classifier" + with pytest.raises(ValueError, match=err_msg): + _get_response_values(my_estimator, X, y, response_method=response_method) + + +@pytest.mark.parametrize("target_type", [None, "continuous"]) +def test_get_response_values_regressor(target_type): + """Check the behaviour of `_get_response_values` with regressor.""" + X, y = make_regression(n_samples=10, random_state=0) + regressor = LinearRegression().fit(X, y) + y_pred, pos_label = _get_response_values( + regressor, + X, + y, + response_method="predict", + target_type=target_type, + ) + assert_array_equal(y_pred, regressor.predict(X)) + assert pos_label is None + + +@pytest.mark.parametrize( + "response_method", + ["predict_proba", "decision_function", "predict"], +) +def test_get_response_values_classifier_unknown_pos_label(response_method): + """Check that `_get_response_values` raises the proper error message with + classifier.""" + X, y = make_classification(n_samples=10, n_classes=2, random_state=0) + classifier = LogisticRegression().fit(X, y) + + # provide a `pos_label` which is not in `y` + err_msg = r"pos_label=whatever is not a valid label: It should be one of \[0 1\]" + with pytest.raises(ValueError, match=err_msg): + _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label="whatever", + ) + + +def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): + """Check that `_get_response_values` will raise an error when `y_pred` has a + single class with `predict_proba`.""" + X, y_two_class = make_classification(n_samples=10, n_classes=2, random_state=0) + y_single_class = 
np.zeros_like(y_two_class) + classifier = DecisionTreeClassifier().fit(X, y_single_class) + + err_msg = ( + r"Got predict_proba of shape \(10, 1\), but need classifier with " + r"two classes" + ) + with pytest.raises(ValueError, match=err_msg): + _get_response_values( + classifier, X, y_two_class, response_method="predict_proba" + ) + + +@pytest.mark.parametrize("target_type", [None, "binary"]) +def test_get_response_values_binary_classifier_decision_function(target_type): + """Check the behaviour of `_get_response_values` with `decision_function` + and binary classifier. + """ + X, y = make_classification( + n_samples=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + classifier = LogisticRegression().fit(X, y) + response_method = "decision_function" + + # default `pos_label` + y_pred, pos_label = _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label=None, + target_type=target_type, + ) + assert_allclose(y_pred, classifier.decision_function(X)) + assert pos_label == 1 + + # when forcing `pos_label=classifier.classes_[0]` + y_pred, pos_label = _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label=classifier.classes_[0], + target_type=target_type, + ) + assert_allclose(y_pred, classifier.decision_function(X) * -1) + assert pos_label == 0 + + +@pytest.mark.parametrize("target_type", [None, "binary"]) +def test_get_response_values_binary_classifier_predict_proba(target_type): + """Check that `_get_response_values` with `predict_proba` and binary + classifier.""" + X, y = make_classification( + n_samples=10, + n_classes=2, + weights=[0.3, 0.7], + random_state=0, + ) + classifier = LogisticRegression().fit(X, y) + response_method = "predict_proba" + + # default `pos_label` + y_pred, pos_label = _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label=None, + target_type=target_type, + ) + assert_allclose(y_pred, classifier.predict_proba(X)[:, 1]) + assert pos_label == 1 + + # when forcing `pos_label=classifier.classes_[0]` + y_pred, pos_label = _get_response_values( + classifier, + X, + y, + response_method=response_method, + pos_label=classifier.classes_[0], + target_type=target_type, + ) + assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) + assert pos_label == 0 From cab4681fe8fc63d9851b3f5f5280b25a5cab2782 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 7 Apr 2022 18:20:13 +0200 Subject: [PATCH 03/15] iter --- sklearn/utils/tests/test_utils.py | 3 +-- sklearn/utils/validation.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index d7dea283f4b27..1f24528f3240d 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -806,8 +806,7 @@ def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): @pytest.mark.parametrize("target_type", [None, "binary"]) def test_get_response_values_binary_classifier_decision_function(target_type): """Check the behaviour of `_get_response_values` with `decision_function` - and binary classifier. 
- """ + and binary classifier.""" X, y = make_classification( n_samples=10, n_classes=2, diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 8c72dc30fdbef..b31e701d2969f 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1778,6 +1778,7 @@ def _check_response_method(estimator, response_method): ---------- estimator : estimator instance Classifier or regressor to check. + response_method : {"predict_proba", "decision_function", "predict"} or \ list of such str Specifies the response method to use get prediction from an estimator From fd4926229bd907b5b9151308b9896f7ede6bf058 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 22 Feb 2023 12:16:41 +0100 Subject: [PATCH 04/15] update versionadded --- sklearn/utils/validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 1a7c98b256a88..5520aebd58bf4 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1830,7 +1830,7 @@ def _allclose_dense_sparse(x, y, rtol=1e-7, atol=1e-9): def _check_response_method(estimator, response_method): """Check if `response_method` is available in estimator and return it. - .. versionadded:: 1.1 + .. versionadded:: 1.3 Parameters ---------- From 930671162063f2912999cb75dd34f279063e52dd Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 22 Feb 2023 13:09:23 +0100 Subject: [PATCH 05/15] TST add __name__ to mocked method --- sklearn/ensemble/tests/test_stacking.py | 8 ++++++-- sklearn/utils/_response.py | 6 ++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index f237961ed7606..fa15c2411a7e1 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -572,9 +572,13 @@ def test_stacking_prefit(Stacker, Estimator, stack_method, final_estimator, X, y # mock out fit and stack_method to be asserted later for _, estimator in estimators: - estimator.fit = Mock() + estimator.fit = Mock(name="fit") stack_func = getattr(estimator, stack_method) - setattr(estimator, stack_method, Mock(side_effect=stack_func)) + predict_method_mocked = Mock(side_effect=stack_func, name=stack_method) + # Mocking a method will not provide an `__name__` while Python methods + # do and we are using it in `_get_response_method`. + predict_method_mocked.__name__ = stack_method + setattr(estimator, stack_method, predict_method_mocked) stacker = Stacker( estimators=estimators, cv="prefit", final_estimator=final_estimator diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index 2dfdeb4593377..0b08b2cde3423 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -87,6 +87,12 @@ def _get_response_values( y_pred = prediction_method(X) classes = estimator.classes_ + if target_type == "multiclass" and prediction_method.__name__ != "predict": + raise ValueError( + "With multiclass target, the response method should be " + f"predict, got {prediction_method.__name__} instead." + ) + if pos_label is not None and pos_label not in classes.tolist(): raise ValueError( f"pos_label={pos_label} is not a valid label: It should be " From 0b16e00c3dd2c11f7a5355caa2abc36033096d2d Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Mon, 6 Mar 2023 04:08:05 -0500 Subject: [PATCH 06/15] ENH Removing y_true and type_of_target from _get_response_values (#13) --- sklearn/calibration.py | 7 +- sklearn/metrics/_plot/base.py | 92 ++++--------------- sklearn/metrics/_plot/det_curve.py | 6 +- .../metrics/_plot/precision_recall_curve.py | 6 +- sklearn/metrics/_plot/roc_curve.py | 6 +- sklearn/metrics/_plot/tests/test_base.py | 51 +++++----- .../_plot/tests/test_common_curve_display.py | 12 +-- sklearn/tests/test_calibration.py | 4 +- sklearn/utils/_response.py | 19 +--- sklearn/utils/tests/test_response.py | 28 ++---- 10 files changed, 73 insertions(+), 158 deletions(-) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index c91a5dc0031b0..8603fa13cd783 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -1292,17 +1292,18 @@ def from_estimator( method_name = f"{cls.__name__}.from_estimator" check_matplotlib_support(method_name) - target_type = type_of_target(y) if not is_classifier(estimator): raise ValueError("'estimator' should be a fitted classifier.") + check_is_fitted(estimator) + if len(estimator.classes_) != 2: + raise ValueError("Estimator must be a binary classifier.") + y_prob, pos_label = _get_response_values( estimator, X, - y, response_method="predict_proba", pos_label=pos_label, - target_type=target_type, ) name = name if name is not None else estimator.__class__.__name__ diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py index 60377e3b10f66..32820db8fec01 100644 --- a/sklearn/metrics/_plot/base.py +++ b/sklearn/metrics/_plot/base.py @@ -1,64 +1,20 @@ from ...base import is_classifier +from ...utils._response import _get_response_values +from ...utils.validation import check_is_fitted -def _check_classifier_response_method(estimator, response_method): - """Return prediction method from the response_method - - Parameters - ---------- - estimator: object - Classifier to check - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. - - Returns - ------- - prediction_method: callable - prediction method of estimator - """ - - if response_method not in ("predict_proba", "decision_function", "auto"): - raise ValueError( - "response_method must be 'predict_proba', 'decision_function' or 'auto'" - ) - - error_msg = "response method {} is not defined in {}" - if response_method != "auto": - prediction_method = getattr(estimator, response_method, None) - if prediction_method is None: - raise ValueError( - error_msg.format(response_method, estimator.__class__.__name__) - ) - else: - predict_proba = getattr(estimator, "predict_proba", None) - decision_function = getattr(estimator, "decision_function", None) - prediction_method = predict_proba or decision_function - if prediction_method is None: - raise ValueError( - error_msg.format( - "decision_function or predict_proba", estimator.__class__.__name__ - ) - ) - - return prediction_method - - -def _get_response(X, estimator, response_method, pos_label=None): +def _get_response_binary(estimator, X, response_method, pos_label=None): """Return response and positive label. Parameters ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Input values. 
- estimator : estimator instance Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` in which the last estimator is a classifier. + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + response_method: {'auto', 'predict_proba', 'decision_function'} Specifies whether to use :term:`predict_proba` or :term:`decision_function` as the target response. If set to 'auto', @@ -85,32 +41,16 @@ def _get_response(X, estimator, response_method, pos_label=None): f" {estimator.__class__.__name__}" ) - if not is_classifier(estimator): + check_is_fitted(estimator) + if not is_classifier(estimator) or len(estimator.classes_) != 2: raise ValueError(classification_error) - prediction_method = _check_classifier_response_method(estimator, response_method) - y_pred = prediction_method(X) - if pos_label is not None: - try: - class_idx = estimator.classes_.tolist().index(pos_label) - except ValueError as e: - raise ValueError( - "The class provided by 'pos_label' is unknown. Got " - f"{pos_label} instead of one of {set(estimator.classes_)}" - ) from e - else: - class_idx = 1 - pos_label = estimator.classes_[class_idx] - - if y_pred.ndim != 1: # `predict_proba` - y_pred_shape = y_pred.shape[1] - if y_pred_shape != 2: - raise ValueError( - f"{classification_error} fit on multiclass ({y_pred_shape} classes)" - " data" - ) - y_pred = y_pred[:, class_idx] - elif pos_label == estimator.classes_[0]: # `decision_function` - y_pred *= -1 + if response_method == "auto": + response_method = ["predict_proba", "decision_function"] - return y_pred, pos_label + return _get_response_values( + estimator, + X, + response_method, + pos_label=pos_label, + ) diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py index b4a868b195dd0..2934746ae1733 100644 --- a/sklearn/metrics/_plot/det_curve.py +++ b/sklearn/metrics/_plot/det_curve.py @@ -1,6 +1,6 @@ import scipy as sp -from .base import _get_response +from .base import _get_response_binary from .. import det_curve from .._base import _check_pos_label_consistency @@ -168,9 +168,9 @@ def from_estimator( name = estimator.__class__.__name__ if name is None else name - y_pred, pos_label = _get_response( - X, + y_pred, pos_label = _get_response_binary( estimator, + X, response_method, pos_label=pos_label, ) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 35cf72e618e84..d53ebec27569f 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,5 +1,5 @@ from sklearn.base import is_classifier -from .base import _get_response +from .base import _get_response_binary from .. import average_precision_score from .. import precision_recall_curve @@ -271,9 +271,9 @@ def from_estimator( check_matplotlib_support(method_name) if not is_classifier(estimator): raise ValueError(f"{method_name} only supports classifiers") - y_pred, pos_label = _get_response( - X, + y_pred, pos_label = _get_response_binary( estimator, + X, response_method, pos_label=pos_label, ) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 256183787e470..4b37681dbee74 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,4 +1,4 @@ -from .base import _get_response +from .base import _get_response_binary from .. import auc from .. 
import roc_curve @@ -231,9 +231,9 @@ def from_estimator( name = estimator.__class__.__name__ if name is None else name - y_pred, pos_label = _get_response( - X, + y_pred, pos_label = _get_response_binary( estimator, + X, response_method=response_method, pos_label=pos_label, ) diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py index 2f67d7dd223f4..302b41b0f8c96 100644 --- a/sklearn/metrics/_plot/tests/test_base.py +++ b/sklearn/metrics/_plot/tests/test_base.py @@ -5,71 +5,74 @@ from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor -from sklearn.metrics._plot.base import _get_response +from sklearn.metrics._plot.base import _get_response_binary + +X, y = load_iris(return_X_y=True) +X_binary, y_binary = X[:100], y[:100] @pytest.mark.parametrize( - "estimator, err_msg, params", + "estimator, X, y, err_msg, params", [ ( DecisionTreeRegressor(), + X_binary, + y_binary, "Expected 'estimator' to be a binary classifier", {"response_method": "auto"}, ), ( DecisionTreeClassifier(), - "The class provided by 'pos_label' is unknown.", + X_binary, + y_binary, + r"pos_label=unknown is not a valid label: It should be one of \[0 1\]", {"response_method": "auto", "pos_label": "unknown"}, ), ( DecisionTreeClassifier(), - "fit on multiclass", + X, + y, + "Expected 'estimator' to be a binary classifier, but got" + " DecisionTreeClassifier", {"response_method": "predict_proba"}, ), ], ) -def test_get_response_error(estimator, err_msg, params): - """Check that we raise the proper error messages in `_get_response`.""" - X, y = load_iris(return_X_y=True) +def test_get_response_error(estimator, X, y, err_msg, params): + """Check that we raise the proper error messages in `_get_response_binary`.""" estimator.fit(X, y) with pytest.raises(ValueError, match=err_msg): - _get_response(X, estimator, **params) + _get_response_binary(estimator, X, **params) def test_get_response_predict_proba(): - """Check the behaviour of `_get_response` using `predict_proba`.""" - X, y = load_iris(return_X_y=True) - X_binary, y_binary = X[:100], y[:100] - + """Check the behaviour of `_get_response_binary` using `predict_proba`.""" classifier = DecisionTreeClassifier().fit(X_binary, y_binary) - y_proba, pos_label = _get_response( - X_binary, classifier, response_method="predict_proba" + y_proba, pos_label = _get_response_binary( + classifier, X_binary, response_method="predict_proba" ) np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) assert pos_label == 1 - y_proba, pos_label = _get_response( - X_binary, classifier, response_method="predict_proba", pos_label=0 + y_proba, pos_label = _get_response_binary( + classifier, X_binary, response_method="predict_proba", pos_label=0 ) np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) assert pos_label == 0 def test_get_response_decision_function(): - """Check the behaviour of `get_response` using `decision_function`.""" - X, y = load_iris(return_X_y=True) - X_binary, y_binary = X[:100], y[:100] - + """Check the behaviour of `_get_response_binary` using `decision_function`.""" classifier = LogisticRegression().fit(X_binary, y_binary) - y_score, pos_label = _get_response( - X_binary, classifier, response_method="decision_function" + y_score, pos_label = _get_response_binary( + classifier, X_binary, response_method="decision_function" ) np.testing.assert_allclose(y_score, classifier.decision_function(X_binary)) assert pos_label == 1 - 
y_score, pos_label = _get_response( - X_binary, classifier, response_method="decision_function", pos_label=0 + y_score, pos_label = _get_response_binary( + classifier, X_binary, response_method="decision_function", pos_label=0 ) np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) assert pos_label == 0 diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 5ed036b77f4d0..9dadce3eb6a9f 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -48,20 +48,20 @@ def test_display_curve_error_non_binary(pyplot, data, Display): [ ( "predict_proba", - "response method predict_proba is not defined in MyClassifier", + "MyClassifier has none of the following attributes: predict_proba.", ), ( "decision_function", - "response method decision_function is not defined in MyClassifier", + "MyClassifier has none of the following attributes: decision_function.", ), ( "auto", - "response method decision_function or predict_proba is not " - "defined in MyClassifier", + "MyClassifier has none of the following attributes: predict_proba," + " decision_function.", ), ( "bad_method", - "response_method must be 'predict_proba', 'decision_function' or 'auto'", + "MyClassifier has none of the following attributes: bad_method.", ), ], ) @@ -86,7 +86,7 @@ def fit(self, X, y): clf = MyClassifier().fit(X, y) - with pytest.raises(ValueError, match=msg): + with pytest.raises(AttributeError, match=msg): Display.from_estimator(clf, X, y, response_method=response_method) diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index 1127f5f0948f8..1f610d80409a8 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -624,7 +624,7 @@ def test_calibration_display_validation(pyplot, iris_data, iris_data_binary): with pytest.raises(ValueError, match=msg): CalibrationDisplay.from_estimator(reg, X, y) - clf = LinearSVC().fit(X, y) + clf = LinearSVC().fit(X_binary, y_binary) msg = "has none of the following attributes: predict_proba." with pytest.raises(AttributeError, match=msg): CalibrationDisplay.from_estimator(clf, X, y) @@ -642,7 +642,7 @@ def test_calibration_display_non_binary(pyplot, iris_data, constructor_name): y_prob = clf.predict_proba(X) if constructor_name == "from_estimator": - msg = "The target y is not binary. Got multiclass type of target." + msg = "Estimator must be a binary classifier." with pytest.raises(ValueError, match=msg): CalibrationDisplay.from_estimator(clf, X, y) else: diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index 0b08b2cde3423..3ea49a34eb6c3 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -1,16 +1,13 @@ import numpy as np -from .multiclass import type_of_target from .validation import _check_response_method def _get_response_values( estimator, X, - y_true, response_method, pos_label=None, - target_type=None, ): """Compute the response values of a classifier or a regressor. @@ -34,9 +31,6 @@ def _get_response_values( X : {array-like, sparse matrix} of shape (n_samples, n_features) Input values. - y_true : array-like of shape (n_samples,) - The true label. - response_method : {"predict_proba", "decision_function", "predict"} or \ list of such str Specifies the response method to use get prediction from an estimator @@ -53,13 +47,6 @@ def _get_response_values( the metrics. 
By default, `estimators.classes_[1]` is considered as the positive class. - target_type : str, default=None - The type of the target `y` as returned by - :func:`~sklearn.utils.multiclass.type_of_target`. If `None`, the type - will be inferred by calling :func:`~sklearn.utils.multiclass.type_of_target`. - Providing the type of the target could save time by avoid calling the - :func:`~sklearn.utils.multiclass.type_of_target` function. - Returns ------- y_pred : ndarray of shape (n_samples,) @@ -81,15 +68,15 @@ def _get_response_values( from sklearn.base import is_classifier # noqa if is_classifier(estimator): - if target_type is None: - target_type = type_of_target(y_true) prediction_method = _check_response_method(estimator, response_method) y_pred = prediction_method(X) classes = estimator.classes_ + target_type = "binary" if len(classes) <= 2 else "multiclass" + if target_type == "multiclass" and prediction_method.__name__ != "predict": raise ValueError( - "With multiclass target, the response method should be " + "With a multiclass estimator, the response method should be " f"predict, got {prediction_method.__name__} instead." ) diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py index ec8440630c6c4..a61e935429c6c 100644 --- a/sklearn/utils/tests/test_response.py +++ b/sklearn/utils/tests/test_response.py @@ -18,23 +18,20 @@ def test_get_response_values_regressor_error(response_method): """Check the error message with regressor an not supported response method.""" my_estimator = _MockEstimatorOnOffPrediction(response_methods=[response_method]) - X, y = "mocking_data", "mocking_target" + X = "mocking_data", "mocking_target" err_msg = f"{my_estimator.__class__.__name__} should be a classifier" with pytest.raises(ValueError, match=err_msg): - _get_response_values(my_estimator, X, y, response_method=response_method) + _get_response_values(my_estimator, X, response_method=response_method) -@pytest.mark.parametrize("target_type", [None, "continuous"]) -def test_get_response_values_regressor(target_type): +def test_get_response_values_regressor(): """Check the behaviour of `_get_response_values` with regressor.""" X, y = make_regression(n_samples=10, random_state=0) regressor = LinearRegression().fit(X, y) y_pred, pos_label = _get_response_values( regressor, X, - y, response_method="predict", - target_type=target_type, ) assert_array_equal(y_pred, regressor.predict(X)) assert pos_label is None @@ -56,7 +53,6 @@ def test_get_response_values_classifier_unknown_pos_label(response_method): _get_response_values( classifier, X, - y, response_method=response_method, pos_label="whatever", ) @@ -74,13 +70,10 @@ def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): r"two classes" ) with pytest.raises(ValueError, match=err_msg): - _get_response_values( - classifier, X, y_two_class, response_method="predict_proba" - ) + _get_response_values(classifier, X, response_method="predict_proba") -@pytest.mark.parametrize("target_type", [None, "binary"]) -def test_get_response_values_binary_classifier_decision_function(target_type): +def test_get_response_values_binary_classifier_decision_function(): """Check the behaviour of `_get_response_values` with `decision_function` and binary classifier.""" X, y = make_classification( @@ -96,10 +89,8 @@ def test_get_response_values_binary_classifier_decision_function(target_type): y_pred, pos_label = _get_response_values( classifier, X, - y, response_method=response_method, pos_label=None, - 
target_type=target_type, ) assert_allclose(y_pred, classifier.decision_function(X)) assert pos_label == 1 @@ -108,17 +99,14 @@ def test_get_response_values_binary_classifier_decision_function(target_type): y_pred, pos_label = _get_response_values( classifier, X, - y, response_method=response_method, pos_label=classifier.classes_[0], - target_type=target_type, ) assert_allclose(y_pred, classifier.decision_function(X) * -1) assert pos_label == 0 -@pytest.mark.parametrize("target_type", [None, "binary"]) -def test_get_response_values_binary_classifier_predict_proba(target_type): +def test_get_response_values_binary_classifier_predict_proba(): """Check that `_get_response_values` with `predict_proba` and binary classifier.""" X, y = make_classification( @@ -134,10 +122,8 @@ def test_get_response_values_binary_classifier_predict_proba(target_type): y_pred, pos_label = _get_response_values( classifier, X, - y, response_method=response_method, pos_label=None, - target_type=target_type, ) assert_allclose(y_pred, classifier.predict_proba(X)[:, 1]) assert pos_label == 1 @@ -146,10 +132,8 @@ def test_get_response_values_binary_classifier_predict_proba(target_type): y_pred, pos_label = _get_response_values( classifier, X, - y, response_method=response_method, pos_label=classifier.classes_[0], - target_type=target_type, ) assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) assert pos_label == 0 From d97a4c7d86c374e25ac122fd9bcfa0bc2dd601fa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 6 Mar 2023 10:39:17 +0100 Subject: [PATCH 07/15] TST remove unecessary method namingt --- sklearn/ensemble/tests/test_stacking.py | 2 +- sklearn/utils/_response.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index fa15c2411a7e1..1ca134ea77923 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -574,7 +574,7 @@ def test_stacking_prefit(Stacker, Estimator, stack_method, final_estimator, X, y for _, estimator in estimators: estimator.fit = Mock(name="fit") stack_func = getattr(estimator, stack_method) - predict_method_mocked = Mock(side_effect=stack_func, name=stack_method) + predict_method_mocked = Mock(side_effect=stack_func) # Mocking a method will not provide an `__name__` while Python methods # do and we are using it in `_get_response_method`. 
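         # (A plain `Mock` auto-creates regular attributes on access but not
         # dunder attributes: `hasattr(Mock(), "__name__")` is False until the
         # attribute is assigned explicitly, as done on the next line.)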
         predict_method_mocked.__name__ = stack_method
diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py
index 3ea49a34eb6c3..1da33f4783158 100644
--- a/sklearn/utils/_response.py
+++ b/sklearn/utils/_response.py
@@ -69,7 +69,6 @@ def _get_response_values(
 
     if is_classifier(estimator):
         prediction_method = _check_response_method(estimator, response_method)
-        y_pred = prediction_method(X)
         classes = estimator.classes_
 
         target_type = "binary" if len(classes) <= 2 else "multiclass"
@@ -88,6 +87,7 @@ def _get_response_values(
         elif pos_label is None and target_type == "binary":
             pos_label = pos_label if pos_label is not None else classes[-1]
 
+        y_pred = prediction_method(X)
         if prediction_method.__name__ == "predict_proba":
             if target_type == "binary" and y_pred.shape[1] <= 2:
                 if y_pred.shape[1] == 2:
From 79bfc3cf2af79a5d39effc397575d6271319c93b Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Fri, 17 Mar 2023 15:17:49 +0100
Subject: [PATCH 08/15] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
---
 sklearn/ensemble/tests/test_stacking.py | 2 +-
 sklearn/utils/_response.py              | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py
index 1ca134ea77923..956997156b455 100644
--- a/sklearn/ensemble/tests/test_stacking.py
+++ b/sklearn/ensemble/tests/test_stacking.py
@@ -575,7 +575,7 @@ def test_stacking_prefit(Stacker, Estimator, stack_method, final_estimator, X, y
         estimator.fit = Mock(name="fit")
         stack_func = getattr(estimator, stack_method)
         predict_method_mocked = Mock(side_effect=stack_func)
-        # Mocking a method will not provide an `__name__` while Python methods
+        # Mocking a method will not provide a `__name__` while Python methods
         # do and we are using it in `_get_response_method`.
predict_method_mocked.__name__ = stack_method setattr(estimator, stack_method, predict_method_mocked) diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index 1da33f4783158..3169991ebe83d 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -103,7 +103,7 @@ def _get_response_values( if target_type == "binary": if pos_label == classes[0]: y_pred *= -1 - else: + else: # estimator is a regressor if response_method != "predict": raise ValueError(f"{estimator.__class__.__name__} should be a classifier") y_pred, pos_label = estimator.predict(X), None From a54d7752a649b4fbe85e3af06d3985ffda67a6b8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 17 Mar 2023 15:56:50 +0100 Subject: [PATCH 09/15] address jeremie comments --- sklearn/calibration.py | 12 +-- sklearn/metrics/_plot/base.py | 56 ------------- sklearn/metrics/_plot/det_curve.py | 5 +- .../metrics/_plot/precision_recall_curve.py | 6 +- sklearn/metrics/_plot/roc_curve.py | 5 +- sklearn/metrics/_plot/tests/test_base.py | 78 ------------------- sklearn/tests/test_calibration.py | 2 +- sklearn/utils/_response.py | 69 +++++++++++++++- sklearn/utils/tests/test_response.py | 78 ++++++++++++++++++- 9 files changed, 153 insertions(+), 158 deletions(-) delete mode 100644 sklearn/metrics/_plot/base.py delete mode 100644 sklearn/metrics/_plot/tests/test_base.py diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 220ca2d7712ab..a0089e131f8dd 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -34,7 +34,7 @@ check_matplotlib_support, _safe_indexing, ) -from .utils._response import _get_response_values +from .utils._response import _get_response_values_binary from .utils.multiclass import check_classification_targets, type_of_target from .utils.parallel import delayed, Parallel @@ -1268,14 +1268,8 @@ def from_estimator( raise ValueError("'estimator' should be a fitted classifier.") check_is_fitted(estimator) - if len(estimator.classes_) != 2: - raise ValueError("Estimator must be a binary classifier.") - - y_prob, pos_label = _get_response_values( - estimator, - X, - response_method="predict_proba", - pos_label=pos_label, + y_prob, pos_label = _get_response_values_binary( + estimator, X, response_method="predict_proba", pos_label=pos_label ) name = name if name is not None else estimator.__class__.__name__ diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py deleted file mode 100644 index 32820db8fec01..0000000000000 --- a/sklearn/metrics/_plot/base.py +++ /dev/null @@ -1,56 +0,0 @@ -from ...base import is_classifier -from ...utils._response import _get_response_values -from ...utils.validation import check_is_fitted - - -def _get_response_binary(estimator, X, response_method, pos_label=None): - """Return response and positive label. - - Parameters - ---------- - estimator : estimator instance - Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` - in which the last estimator is a classifier. - - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Input values. - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. - - pos_label : str or int, default=None - The class considered as the positive class when computing - the metrics. 
By default, `estimators.classes_[1]` is - considered as the positive class. - - Returns - ------- - y_pred: ndarray of shape (n_samples,) - Target scores calculated from the provided response_method - and pos_label. - - pos_label: str or int - The class considered as the positive class when computing - the metrics. - """ - classification_error = ( - "Expected 'estimator' to be a binary classifier, but got" - f" {estimator.__class__.__name__}" - ) - - check_is_fitted(estimator) - if not is_classifier(estimator) or len(estimator.classes_) != 2: - raise ValueError(classification_error) - - if response_method == "auto": - response_method = ["predict_proba", "decision_function"] - - return _get_response_values( - estimator, - X, - response_method, - pos_label=pos_label, - ) diff --git a/sklearn/metrics/_plot/det_curve.py b/sklearn/metrics/_plot/det_curve.py index 2934746ae1733..f9832fed41847 100644 --- a/sklearn/metrics/_plot/det_curve.py +++ b/sklearn/metrics/_plot/det_curve.py @@ -1,11 +1,10 @@ import scipy as sp -from .base import _get_response_binary - from .. import det_curve from .._base import _check_pos_label_consistency from ...utils import check_matplotlib_support +from ...utils._response import _get_response_values_binary class DetCurveDisplay: @@ -168,7 +167,7 @@ def from_estimator( name = estimator.__class__.__name__ if name is None else name - y_pred, pos_label = _get_response_binary( + y_pred, pos_label = _get_response_values_binary( estimator, X, response_method, diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index d53ebec27569f..9195c6512df9a 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,5 +1,4 @@ -from sklearn.base import is_classifier -from .base import _get_response_binary +from ...base import is_classifier from .. import average_precision_score from .. import precision_recall_curve @@ -7,6 +6,7 @@ from .._classification import check_consistent_length from ...utils import check_matplotlib_support +from ...utils._response import _get_response_values_binary class PrecisionRecallDisplay: @@ -271,7 +271,7 @@ def from_estimator( check_matplotlib_support(method_name) if not is_classifier(estimator): raise ValueError(f"{method_name} only supports classifiers") - y_pred, pos_label = _get_response_binary( + y_pred, pos_label = _get_response_values_binary( estimator, X, response_method, diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 4b37681dbee74..65d639679449d 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,10 +1,9 @@ -from .base import _get_response_binary - from .. import auc from .. 
import roc_curve from .._base import _check_pos_label_consistency from ...utils import check_matplotlib_support +from ...utils._response import _get_response_values_binary class RocCurveDisplay: @@ -231,7 +230,7 @@ def from_estimator( name = estimator.__class__.__name__ if name is None else name - y_pred, pos_label = _get_response_binary( + y_pred, pos_label = _get_response_values_binary( estimator, X, response_method=response_method, diff --git a/sklearn/metrics/_plot/tests/test_base.py b/sklearn/metrics/_plot/tests/test_base.py deleted file mode 100644 index 302b41b0f8c96..0000000000000 --- a/sklearn/metrics/_plot/tests/test_base.py +++ /dev/null @@ -1,78 +0,0 @@ -import numpy as np -import pytest - -from sklearn.datasets import load_iris -from sklearn.linear_model import LogisticRegression -from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor - -from sklearn.metrics._plot.base import _get_response_binary - -X, y = load_iris(return_X_y=True) -X_binary, y_binary = X[:100], y[:100] - - -@pytest.mark.parametrize( - "estimator, X, y, err_msg, params", - [ - ( - DecisionTreeRegressor(), - X_binary, - y_binary, - "Expected 'estimator' to be a binary classifier", - {"response_method": "auto"}, - ), - ( - DecisionTreeClassifier(), - X_binary, - y_binary, - r"pos_label=unknown is not a valid label: It should be one of \[0 1\]", - {"response_method": "auto", "pos_label": "unknown"}, - ), - ( - DecisionTreeClassifier(), - X, - y, - "Expected 'estimator' to be a binary classifier, but got" - " DecisionTreeClassifier", - {"response_method": "predict_proba"}, - ), - ], -) -def test_get_response_error(estimator, X, y, err_msg, params): - """Check that we raise the proper error messages in `_get_response_binary`.""" - - estimator.fit(X, y) - with pytest.raises(ValueError, match=err_msg): - _get_response_binary(estimator, X, **params) - - -def test_get_response_predict_proba(): - """Check the behaviour of `_get_response_binary` using `predict_proba`.""" - classifier = DecisionTreeClassifier().fit(X_binary, y_binary) - y_proba, pos_label = _get_response_binary( - classifier, X_binary, response_method="predict_proba" - ) - np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) - assert pos_label == 1 - - y_proba, pos_label = _get_response_binary( - classifier, X_binary, response_method="predict_proba", pos_label=0 - ) - np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) - assert pos_label == 0 - - -def test_get_response_decision_function(): - """Check the behaviour of `_get_response_binary` using `decision_function`.""" - classifier = LogisticRegression().fit(X_binary, y_binary) - y_score, pos_label = _get_response_binary( - classifier, X_binary, response_method="decision_function" - ) - np.testing.assert_allclose(y_score, classifier.decision_function(X_binary)) - assert pos_label == 1 - - y_score, pos_label = _get_response_binary( - classifier, X_binary, response_method="decision_function", pos_label=0 - ) - np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) - assert pos_label == 0 diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index 517b735e19d86..b4a8f84e8717a 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -622,7 +622,7 @@ def test_calibration_display_non_binary(pyplot, iris_data, constructor_name): y_prob = clf.predict_proba(X) if constructor_name == "from_estimator": - msg = "Estimator must be a binary classifier." 
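        # (the expected message below mirrors the ValueError raised by the new
        # `_get_response_values_binary` helper, which reports the class count)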
+        msg = "to be a binary classifier. Got 3 classes instead."
         with pytest.raises(ValueError, match=msg):
             CalibrationDisplay.from_estimator(clf, X, y)
     else:
diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py
index 3169991ebe83d..a1c8af33f7cb3 100644
--- a/sklearn/utils/_response.py
+++ b/sklearn/utils/_response.py
@@ -1,6 +1,11 @@
+"""Utilities to get the response values of a classifier or a regressor.
+
+It allows for uniform checks and validation.
+"""
 import numpy as np
 
-from .validation import _check_response_method
+from ..base import is_classifier
+from .validation import _check_response_method, check_is_fitted
 
 
 def _get_response_values(
@@ -105,7 +110,69 @@ def _get_response_values(
             y_pred *= -1
     else:  # estimator is a regressor
         if response_method != "predict":
+            raise ValueError(
+                f"{estimator.__class__.__name__} should either be a classifier to be "
+                f"used with response_method={response_method} or the response_method "
+                "should be 'predict'. Got a regressor with response_method="
+                f"{response_method} instead."
+            )
             raise ValueError(f"{estimator.__class__.__name__} should be a classifier")
 
         y_pred, pos_label = estimator.predict(X), None
 
     return y_pred, pos_label
+
+
+def _get_response_values_binary(estimator, X, response_method, pos_label=None):
+    """Compute the response values of a binary classifier.
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline`
+        in which the last estimator is a binary classifier.
+
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        Input values.
+
+    response_method : {'auto', 'predict_proba', 'decision_function'}
+        Specifies whether to use :term:`predict_proba` or
+        :term:`decision_function` as the target response. If set to 'auto',
+        :term:`predict_proba` is tried first and, if it does not exist,
+        :term:`decision_function` is tried next.
+
+    pos_label : str or int, default=None
+        The class considered as the positive class when computing
+        the metrics. By default, `estimator.classes_[1]` is
+        considered as the positive class.
+
+    Returns
+    -------
+    y_pred : ndarray of shape (n_samples,)
+        Target scores calculated from the provided response_method
+        and pos_label.
+
+    pos_label : str or int
+        The class considered as the positive class when computing
+        the metrics.
+    """
+    classification_error = "Expected 'estimator' to be a binary classifier."
+
+    check_is_fitted(estimator)
+    if not is_classifier(estimator):
+        raise ValueError(
+            classification_error + f" Got {estimator.__class__.__name__} instead."
+        )
+    elif len(estimator.classes_) != 2:
+        raise ValueError(
+            classification_error + f" Got {len(estimator.classes_)} classes instead."
+        )
+
+    if response_method == "auto":
+        response_method = ["predict_proba", "decision_function"]
+
+    return _get_response_values(
+        estimator,
+        X,
+        response_method,
+        pos_label=pos_label,
+    )
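
As a quick illustration of the contract the curve displays now rely on, here is a
minimal sketch of the new helper in action (not part of the diff; it assumes only
a fitted binary classifier, and the toy data is made up):

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils._response import _get_response_values_binary

    X = [[0.0], [1.0], [2.0], [3.0]]
    y = [0, 0, 1, 1]
    clf = LogisticRegression().fit(X, y)

    # "auto" tries predict_proba first and falls back to decision_function.
    y_score, pos_label = _get_response_values_binary(clf, X, response_method="auto")
    # y_score is clf.predict_proba(X)[:, 1]; pos_label is clf.classes_[1], i.e. 1.

    # Asking for pos_label=0 returns the scores of the other column instead.
    y_score, pos_label = _get_response_values_binary(
        clf, X, response_method="predict_proba", pos_label=0
    )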
diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py
index a61e935429c6c..9d6d90ddd94ae 100644
--- a/sklearn/utils/tests/test_response.py
+++ b/sklearn/utils/tests/test_response.py
@@ -1,16 +1,20 @@
 import numpy as np
 import pytest
 
-from sklearn.datasets import make_classification, make_regression
+from sklearn.datasets import load_iris, make_classification, make_regression
 from sklearn.linear_model import (
     LinearRegression,
     LogisticRegression,
 )
-from sklearn.tree import DecisionTreeClassifier
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.utils._mocking import _MockEstimatorOnOffPrediction
 from sklearn.utils._testing import assert_allclose, assert_array_equal
 
-from sklearn.utils._response import _get_response_values
+from sklearn.utils._response import _get_response_values, _get_response_values_binary
+
+
+X, y = load_iris(return_X_y=True)
+X_binary, y_binary = X[:100], y[:100]
 
 
 @pytest.mark.parametrize("response_method", ["decision_function", "predict_proba"])
@@ -19,7 +23,7 @@ def test_get_response_values_regressor_error(response_method):
     method."""
     my_estimator = _MockEstimatorOnOffPrediction(response_methods=[response_method])
     X = "mocking_data", "mocking_target"
-    err_msg = f"{my_estimator.__class__.__name__} should be a classifier"
+    err_msg = f"{my_estimator.__class__.__name__} should either be a classifier"
     with pytest.raises(ValueError, match=err_msg):
         _get_response_values(my_estimator, X, response_method=response_method)
 
@@ -137,3 +141,69 @@ def test_get_response_values_binary_classifier_predict_proba():
     )
     assert_allclose(y_pred, classifier.predict_proba(X)[:, 0])
     assert pos_label == 0
+
+
+@pytest.mark.parametrize(
+    "estimator, X, y, err_msg, params",
+    [
+        (
+            DecisionTreeRegressor(),
+            X_binary,
+            y_binary,
+            "Expected 'estimator' to be a binary classifier",
+            {"response_method": "auto"},
+        ),
+        (
+            DecisionTreeClassifier(),
+            X_binary,
+            y_binary,
+            r"pos_label=unknown is not a valid label: It should be one of \[0 1\]",
+            {"response_method": "auto", "pos_label": "unknown"},
+        ),
+        (
+            DecisionTreeClassifier(),
+            X,
+            y,
+            "be a binary classifier. Got 3 classes instead.",
Got 3 classes instead.", + {"response_method": "predict_proba"}, + ), + ], +) +def test_get_response_error(estimator, X, y, err_msg, params): + """Check that we raise the proper error messages in _get_response_values_binary.""" + + estimator.fit(X, y) + with pytest.raises(ValueError, match=err_msg): + _get_response_values_binary(estimator, X, **params) + + +def test_get_response_predict_proba(): + """Check the behaviour of `_get_response_values_binary` using `predict_proba`.""" + classifier = DecisionTreeClassifier().fit(X_binary, y_binary) + y_proba, pos_label = _get_response_values_binary( + classifier, X_binary, response_method="predict_proba" + ) + np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) + assert pos_label == 1 + + y_proba, pos_label = _get_response_values_binary( + classifier, X_binary, response_method="predict_proba", pos_label=0 + ) + np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) + assert pos_label == 0 + + +def test_get_response_decision_function(): + """Check the behaviour of `_get_response_values_binary` using decision_function.""" + classifier = LogisticRegression().fit(X_binary, y_binary) + y_score, pos_label = _get_response_values_binary( + classifier, X_binary, response_method="decision_function" + ) + np.testing.assert_allclose(y_score, classifier.decision_function(X_binary)) + assert pos_label == 1 + + y_score, pos_label = _get_response_values_binary( + classifier, X_binary, response_method="decision_function", pos_label=0 + ) + np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) + assert pos_label == 0 From e27e990dc9da2ba427147b85c37e0c1649c7ad8f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 20 Mar 2023 10:20:44 +0100 Subject: [PATCH 10/15] fix match string for error --- sklearn/metrics/_plot/tests/test_common_curve_display.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 9dadce3eb6a9f..51e9cea338d95 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -36,9 +36,7 @@ def test_display_curve_error_non_binary(pyplot, data, Display): X, y = data clf = DecisionTreeClassifier().fit(X, y) - msg = ( - "Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier" - ) + msg = "Expected 'estimator' to be a binary classifier." with pytest.raises(ValueError, match=msg): Display.from_estimator(clf, X, y) From 43094a156e3b351066eb17bc15027b5f39f5cc71 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 20 Mar 2023 10:48:08 +0100 Subject: [PATCH 11/15] iter --- sklearn/metrics/_plot/tests/test_precision_recall_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 4d514fa1f32b3..2b2f7c439b660 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -39,7 +39,7 @@ def test_precision_recall_display_validation(pyplot): with pytest.raises(ValueError, match=err_msg): PrecisionRecallDisplay.from_estimator(regressor, X, y) - err_msg = "Expected 'estimator' to be a binary classifier, but got SVC" + err_msg = "Expected 'estimator' to be a binary classifier." 
From 43094a156e3b351066eb17bc15027b5f39f5cc71 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 20 Mar 2023 10:48:08 +0100
Subject: [PATCH 11/15] update expected error message in precision-recall
 display test

---
 sklearn/metrics/_plot/tests/test_precision_recall_display.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index 4d514fa1f32b3..2b2f7c439b660 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -39,7 +39,7 @@ def test_precision_recall_display_validation(pyplot):
     with pytest.raises(ValueError, match=err_msg):
         PrecisionRecallDisplay.from_estimator(regressor, X, y)
 
-    err_msg = "Expected 'estimator' to be a binary classifier, but got SVC"
+    err_msg = "Expected 'estimator' to be a binary classifier."
     with pytest.raises(ValueError, match=err_msg):
         PrecisionRecallDisplay.from_estimator(classifier, X, y)
 
From 68ceba6b1dbab0b1c6a5981b63cfadc1b550f7f1 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 23 Mar 2023 15:19:04 +0100
Subject: [PATCH 12/15] address Jeremie's comments

---
 sklearn/calibration.py                                   | 4 ----
 sklearn/metrics/_plot/precision_recall_curve.py          | 5 +----
 sklearn/metrics/_plot/tests/test_common_curve_display.py | 2 +-
 3 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index a0089e131f8dd..31f8b67458f78 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -25,7 +25,6 @@
     RegressorMixin,
     clone,
     MetaEstimatorMixin,
-    is_classifier,
 )
 from .preprocessing import label_binarize, LabelEncoder
 from .utils import (
@@ -1264,9 +1263,6 @@ def from_estimator(
         method_name = f"{cls.__name__}.from_estimator"
         check_matplotlib_support(method_name)
 
-        if not is_classifier(estimator):
-            raise ValueError("'estimator' should be a fitted classifier.")
-
         check_is_fitted(estimator)
         y_prob, pos_label = _get_response_values_binary(
             estimator, X, response_method="predict_proba", pos_label=pos_label
diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 9195c6512df9a..3ab49b737facd 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,5 +1,3 @@
-from ...base import is_classifier
-
 from .. import average_precision_score
 from .. import precision_recall_curve
 from .._base import _check_pos_label_consistency
@@ -269,8 +267,7 @@ def from_estimator(
         """
         method_name = f"{cls.__name__}.from_estimator"
         check_matplotlib_support(method_name)
-        if not is_classifier(estimator):
-            raise ValueError(f"{method_name} only supports classifiers")
+
         y_pred, pos_label = _get_response_values_binary(
             estimator,
             X,
diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py
index 51e9cea338d95..27730893bb05c 100644
--- a/sklearn/metrics/_plot/tests/test_common_curve_display.py
+++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py
@@ -36,7 +36,7 @@ def test_display_curve_error_non_binary(pyplot, data, Display):
     X, y = data
     clf = DecisionTreeClassifier().fit(X, y)
 
-    msg = "Expected 'estimator' to be a binary classifier."
+    msg = "Expected 'estimator' to be a binary classifier. Got 3 classes instead."
     with pytest.raises(ValueError, match=msg):
         Display.from_estimator(clf, X, y)
 
From bbaa52018c2893e6fb501402b9c22f51dc2f2d82 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 23 Mar 2023 15:41:23 +0100
Subject: [PATCH 13/15] update regex for test

---
 sklearn/metrics/_plot/tests/test_precision_recall_display.py | 2 +-
 sklearn/tests/test_calibration.py                            | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index 2b2f7c439b660..4b5c0c989e27f 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -35,7 +35,7 @@ def test_precision_recall_display_validation(pyplot):
     classifier = SVC(probability=True).fit(X, y)
     y_pred_classifier = classifier.predict_proba(X)[:, -1]
 
-    err_msg = "PrecisionRecallDisplay.from_estimator only supports classifiers"
+    err_msg = "Expected 'estimator' to be a binary classifier. Got SVR instead."
with pytest.raises(ValueError, match=err_msg): PrecisionRecallDisplay.from_estimator(regressor, X, y) diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index b4a8f84e8717a..fff774c3fc490 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -600,7 +600,7 @@ def test_calibration_display_validation(pyplot, iris_data, iris_data_binary): X_binary, y_binary = iris_data_binary reg = LinearRegression().fit(X, y) - msg = "'estimator' should be a fitted classifier" + msg = "Expected 'estimator' to be a binary classifier. Got LinearRegression" with pytest.raises(ValueError, match=msg): CalibrationDisplay.from_estimator(reg, X, y) From 14057975c9e77c0f4780a8deb15ec6a5ec30fc8b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 Mar 2023 16:36:17 +0100 Subject: [PATCH 14/15] TST cover new test --- sklearn/utils/_response.py | 1 - sklearn/utils/tests/test_response.py | 20 ++++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index a1c8af33f7cb3..50b9409c8276d 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -116,7 +116,6 @@ def _get_response_values( "should be 'predict'. Got a regressor with response_method=" f"{response_method} instead." ) - raise ValueError(f"{estimator.__class__.__name__} should be a classifier") y_pred, pos_label = estimator.predict(X), None return y_pred, pos_label diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py index 9d6d90ddd94ae..0e2ce5fe5f038 100644 --- a/sklearn/utils/tests/test_response.py +++ b/sklearn/utils/tests/test_response.py @@ -6,6 +6,7 @@ LinearRegression, LogisticRegression, ) +from sklearn.svm import SVC from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.utils._mocking import _MockEstimatorOnOffPrediction from sklearn.utils._testing import assert_allclose, assert_array_equal @@ -28,6 +29,25 @@ def test_get_response_values_regressor_error(response_method): _get_response_values(my_estimator, X, response_method=response_method) +@pytest.mark.parametrize( + "estimator, response_method", + [ + (DecisionTreeClassifier(), "predict_proba"), + (SVC(), "decision_function"), + ], +) +def test_get_response_values_error_multiclass_classifier(estimator, response_method): + """Check that we raise an error with multiclass classifier and requesting + response values different from `predict`.""" + X, y = make_classification( + n_samples=10, n_clusters_per_class=1, n_classes=3, random_state=0 + ) + classifier = estimator.fit(X, y) + err_msg = "With a multiclass estimator, the response method should be predict" + with pytest.raises(ValueError, match=err_msg): + _get_response_values(classifier, X, response_method=response_method) + + def test_get_response_values_regressor(): """Check the behaviour of `_get_response_values` with regressor.""" X, y = make_regression(n_samples=10, random_state=0) From 14b17eb10db0dcb4ad70f574a0df994fc0ba6c03 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 24 Mar 2023 16:39:53 +0100 Subject: [PATCH 15/15] TST more coverage --- sklearn/utils/tests/test_mocking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/utils/tests/test_mocking.py b/sklearn/utils/tests/test_mocking.py index 8cdac2e6a53f3..718c62d5cc83b 100644 --- a/sklearn/utils/tests/test_mocking.py +++ b/sklearn/utils/tests/test_mocking.py @@ -207,5 +207,6 @@ def test_mock_estimator_on_off_prediction(iris, 
response_methods): for response in possible_responses: if response in response_methods: assert hasattr(estimator, response) + assert getattr(estimator, response)(X) == response else: assert not hasattr(estimator, response)
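
For completeness, a minimal sketch of how the mock used throughout these tests
behaves once the series is applied (illustration only; the mock ignores its
input data, so any placeholder value works):

    from sklearn.utils._mocking import _MockEstimatorOnOffPrediction

    est = _MockEstimatorOnOffPrediction(response_methods=["predict", "predict_proba"])

    # available_if hides any response method that was not requested...
    assert hasattr(est, "predict_proba")
    assert not hasattr(est, "decision_function")

    # ...and each enabled method simply echoes its own name, which is what the
    # new `assert getattr(estimator, response)(X) == response` assertion checks.
    assert est.predict_proba("placeholder data") == "predict_proba"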