diff --git a/doc/whats_new/upcoming_changes/metadata-routing/30833.feature.rst b/doc/whats_new/upcoming_changes/metadata-routing/30833.feature.rst new file mode 100644 index 0000000000000..e46420e9ee2d2 --- /dev/null +++ b/doc/whats_new/upcoming_changes/metadata-routing/30833.feature.rst @@ -0,0 +1,4 @@ +- :class:`ensemble.BaggingClassifier` and :class:`ensemble.BaggingRegressor` now support + metadata routing through their `predict`, `predict_proba`, `predict_log_proba` and + `decision_function` methods and pass `**params` to the underlying estimators. + By :user:`Stefanie Senger `. diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py index 20013e1f6d000..901c63c9250bc 100644 --- a/sklearn/ensemble/_bagging.py +++ b/sklearn/ensemble/_bagging.py @@ -202,14 +202,23 @@ def _parallel_build_estimators( return estimators, estimators_features -def _parallel_predict_proba(estimators, estimators_features, X, n_classes): +def _parallel_predict_proba( + estimators, + estimators_features, + X, + n_classes, + predict_params=None, + predict_proba_params=None, +): """Private function used to compute (proba-)predictions within a job.""" n_samples = X.shape[0] proba = np.zeros((n_samples, n_classes)) for estimator, features in zip(estimators, estimators_features): if hasattr(estimator, "predict_proba"): - proba_estimator = estimator.predict_proba(X[:, features]) + proba_estimator = estimator.predict_proba( + X[:, features], **(predict_params or {}) + ) if n_classes == len(estimator.classes_): proba += proba_estimator @@ -221,7 +230,9 @@ def _parallel_predict_proba(estimators, estimators_features, X, n_classes): else: # Resort to voting - predictions = estimator.predict(X[:, features]) + predictions = estimator.predict( + X[:, features], **(predict_proba_params or {}) + ) for i in range(n_samples): proba[i, predictions[i]] += 1 @@ -229,7 +240,7 @@ def _parallel_predict_proba(estimators, estimators_features, X, n_classes): return proba -def _parallel_predict_log_proba(estimators, estimators_features, X, n_classes): +def _parallel_predict_log_proba(estimators, estimators_features, X, n_classes, params): """Private function used to compute log probabilities within a job.""" n_samples = X.shape[0] log_proba = np.empty((n_samples, n_classes)) @@ -237,7 +248,7 @@ def _parallel_predict_log_proba(estimators, estimators_features, X, n_classes): all_classes = np.arange(n_classes, dtype=int) for estimator, features in zip(estimators, estimators_features): - log_proba_estimator = estimator.predict_log_proba(X[:, features]) + log_proba_estimator = estimator.predict_log_proba(X[:, features], **params) if n_classes == len(estimator.classes_): log_proba = np.logaddexp(log_proba, log_proba_estimator) @@ -254,18 +265,18 @@ def _parallel_predict_log_proba(estimators, estimators_features, X, n_classes): return log_proba -def _parallel_decision_function(estimators, estimators_features, X): +def _parallel_decision_function(estimators, estimators_features, X, params): """Private function used to compute decisions within a job.""" return sum( - estimator.decision_function(X[:, features]) + estimator.decision_function(X[:, features], **params) for estimator, features in zip(estimators, estimators_features) ) -def _parallel_predict_regression(estimators, estimators_features, X): +def _parallel_predict_regression(estimators, estimators_features, X, params): """Private function used to compute predictions within a job.""" return sum( - estimator.predict(X[:, features]) + estimator.predict(X[:, features], **params) for estimator, features in zip(estimators, estimators_features) ) @@ -615,10 +626,47 @@ def get_metadata_routing(self): routing information. """ router = MetadataRouter(owner=self.__class__.__name__) - router.add( - estimator=self._get_estimator(), - method_mapping=MethodMapping().add(callee="fit", caller="fit"), + + method_mapping = MethodMapping() + method_mapping.add(caller="fit", callee="fit").add( + caller="decision_function", callee="decision_function" ) + + # the router needs to be built depending on whether the sub-estimator has a + # `predict_proba` method (as BaggingClassifier decides dynamically at runtime): + if hasattr(self._get_estimator(), "predict_proba"): + ( + method_mapping.add(caller="predict", callee="predict_proba").add( + caller="predict_proba", callee="predict_proba" + ) + ) + + else: + ( + method_mapping.add(caller="predict", callee="predict").add( + caller="predict_proba", callee="predict" + ) + ) + + # the router needs to be built depending on whether the sub-estimator has a + # `predict_log_proba` method (as BaggingClassifier decides dynamically at + # runtime): + if hasattr(self._get_estimator(), "predict_log_proba"): + method_mapping.add(caller="predict_log_proba", callee="predict_log_proba") + + else: + # if `predict_log_proba` is not available in BaggingClassifier's + # sub-estimator, the routing should go to its `predict_proba` if it is + # available or else to its `predict` method; according to how + # `sample_weight` is passed to the respective methods dynamically at + # runtime: + if hasattr(self._get_estimator(), "predict_proba"): + method_mapping.add(caller="predict_log_proba", callee="predict_proba") + + else: + method_mapping.add(caller="predict_log_proba", callee="predict") + + router.add(estimator=self._get_estimator(), method_mapping=method_mapping) return router @abstractmethod @@ -882,7 +930,7 @@ def _validate_y(self, y): return y - def predict(self, X): + def predict(self, X, **params): """Predict class for X. The predicted class of an input sample is computed as the class with @@ -895,15 +943,28 @@ def predict(self, X): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. + **params : dict + Parameters routed to the `predict_proba` (if available) or the `predict` + method (otherwise) of the sub-estimators via the metadata routing API. + + .. versionadded:: 1.7 + + Only available if + `sklearn.set_config(enable_metadata_routing=True)` is set. See + :ref:`Metadata Routing User Guide ` for more + details. + Returns ------- y : ndarray of shape (n_samples,) The predicted classes. """ - predicted_probabilitiy = self.predict_proba(X) + _raise_for_params(params, self, "predict") + + predicted_probabilitiy = self.predict_proba(X, **params) return self.classes_.take((np.argmax(predicted_probabilitiy, axis=1)), axis=0) - def predict_proba(self, X): + def predict_proba(self, X, **params): """Predict class probabilities for X. The predicted class probabilities of an input sample is computed as @@ -919,12 +980,25 @@ def predict_proba(self, X): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. + **params : dict + Parameters routed to the `predict_proba` (if available) or the `predict` + method (otherwise) of the sub-estimators via the metadata routing API. + + .. versionadded:: 1.7 + + Only available if + `sklearn.set_config(enable_metadata_routing=True)` is set. See + :ref:`Metadata Routing User Guide ` for more + details. + Returns ------- p : ndarray of shape (n_samples, n_classes) The class probabilities of the input samples. The order of the classes corresponds to that in the attribute :term:`classes_`. """ + _raise_for_params(params, self, "predict_proba") + check_is_fitted(self) # Check data X = validate_data( @@ -936,6 +1010,12 @@ def predict_proba(self, X): reset=False, ) + if _routing_enabled(): + routed_params = process_routing(self, "predict_proba", **params) + else: + routed_params = Bunch() + routed_params.estimator = Bunch(predict_proba=Bunch()) + # Parallel loop n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs) @@ -947,6 +1027,8 @@ def predict_proba(self, X): self.estimators_features_[starts[i] : starts[i + 1]], X, self.n_classes_, + predict_params=routed_params.estimator.get("predict", None), + predict_proba_params=routed_params.estimator.get("predict_proba", None), ) for i in range(n_jobs) ) @@ -956,7 +1038,7 @@ def predict_proba(self, X): return proba - def predict_log_proba(self, X): + def predict_log_proba(self, X, **params): """Predict class log-probabilities for X. The predicted class log-probabilities of an input sample is computed as @@ -969,13 +1051,29 @@ def predict_log_proba(self, X): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. + **params : dict + Parameters routed to the `predict_log_proba`, the `predict_proba` or the + `proba` method of the sub-estimators via the metadata routing API. The + routing is tried in the mentioned order depending on whether this method is + available on the sub-estimator. + + .. versionadded:: 1.7 + + Only available if + `sklearn.set_config(enable_metadata_routing=True)` is set. See + :ref:`Metadata Routing User Guide ` for more + details. + Returns ------- p : ndarray of shape (n_samples, n_classes) The class log-probabilities of the input samples. The order of the classes corresponds to that in the attribute :term:`classes_`. """ + _raise_for_params(params, self, "predict_log_proba") + check_is_fitted(self) + if hasattr(self.estimator_, "predict_log_proba"): # Check data X = validate_data( @@ -987,6 +1085,12 @@ def predict_log_proba(self, X): reset=False, ) + if _routing_enabled(): + routed_params = process_routing(self, "predict_log_proba", **params) + else: + routed_params = Bunch() + routed_params.estimator = Bunch(predict_log_proba=Bunch()) + # Parallel loop n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs) @@ -996,6 +1100,7 @@ def predict_log_proba(self, X): self.estimators_features_[starts[i] : starts[i + 1]], X, self.n_classes_, + params=routed_params.estimator.predict_log_proba, ) for i in range(n_jobs) ) @@ -1009,14 +1114,14 @@ def predict_log_proba(self, X): log_proba -= np.log(self.n_estimators) else: - log_proba = np.log(self.predict_proba(X)) + log_proba = np.log(self.predict_proba(X, **params)) return log_proba @available_if( _estimator_has("decision_function", delegates=("estimators_", "estimator")) ) - def decision_function(self, X): + def decision_function(self, X, **params): """Average of the decision functions of the base classifiers. Parameters @@ -1025,6 +1130,17 @@ def decision_function(self, X): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. + **params : dict + Parameters routed to the `decision_function` method of the sub-estimators + via the metadata routing API. + + .. versionadded:: 1.7 + + Only available if + `sklearn.set_config(enable_metadata_routing=True)` is set. See + :ref:`Metadata Routing User Guide ` for more + details. + Returns ------- score : ndarray of shape (n_samples, k) @@ -1033,6 +1149,8 @@ def decision_function(self, X): ``classes_``. Regression and binary classification are special cases with ``k == 1``, otherwise ``k==n_classes``. """ + _raise_for_params(params, self, "decision_function") + check_is_fitted(self) # Check data @@ -1045,6 +1163,12 @@ def decision_function(self, X): reset=False, ) + if _routing_enabled(): + routed_params = process_routing(self, "decision_function", **params) + else: + routed_params = Bunch() + routed_params.estimator = Bunch(decision_function=Bunch()) + # Parallel loop n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs) @@ -1053,6 +1177,7 @@ def decision_function(self, X): self.estimators_[starts[i] : starts[i + 1]], self.estimators_features_[starts[i] : starts[i + 1]], X, + params=routed_params.estimator.decision_function, ) for i in range(n_jobs) ) @@ -1251,7 +1376,7 @@ def __init__( verbose=verbose, ) - def predict(self, X): + def predict(self, X, **params): """Predict regression target for X. The predicted regression target of an input sample is computed as the @@ -1263,11 +1388,24 @@ def predict(self, X): The training input samples. Sparse matrices are accepted only if they are supported by the base estimator. + **params : dict + Parameters routed to the `predict` method of the sub-estimators via the + metadata routing API. + + .. versionadded:: 1.7 + + Only available if + `sklearn.set_config(enable_metadata_routing=True)` is set. See + :ref:`Metadata Routing User Guide ` for more + details. + Returns ------- y : ndarray of shape (n_samples,) The predicted values. """ + _raise_for_params(params, self, "predict") + check_is_fitted(self) # Check data X = validate_data( @@ -1279,6 +1417,12 @@ def predict(self, X): reset=False, ) + if _routing_enabled(): + routed_params = process_routing(self, "predict", **params) + else: + routed_params = Bunch() + routed_params.estimator = Bunch(predict=Bunch()) + # Parallel loop n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs) @@ -1287,6 +1431,7 @@ def predict(self, X): self.estimators_[starts[i] : starts[i + 1]], self.estimators_features_[starts[i] : starts[i + 1]], X, + params=routed_params.estimator.predict, ) for i in range(n_jobs) ) diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index f5386804d77d7..d8b1ce9091043 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -11,7 +11,7 @@ import numpy as np import pytest -import sklearn +from sklearn import config_context from sklearn.base import BaseEstimator from sklearn.datasets import load_diabetes, load_iris, make_hastie_10_2 from sklearn.dummy import DummyClassifier, DummyRegressor @@ -33,6 +33,13 @@ from sklearn.preprocessing import FunctionTransformer, scale from sklearn.random_projection import SparseRandomProjection from sklearn.svm import SVC, SVR +from sklearn.tests.metadata_routing_common import ( + ConsumingClassifierWithOnlyPredict, + ConsumingClassifierWithoutPredictLogProba, + ConsumingClassifierWithoutPredictProba, + _Registry, + check_recorded_metadata, +) from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.utils import check_random_state from sklearn.utils._testing import assert_array_almost_equal, assert_array_equal @@ -944,6 +951,11 @@ def test_bagging_allow_nan_tag(bagging, expected_allow_nan): assert bagging.__sklearn_tags__().input_tags.allow_nan == expected_allow_nan +# Metadata Routing Tests +# ====================== + + +@config_context(enable_metadata_routing=True) @pytest.mark.parametrize( "model", [ @@ -957,8 +969,62 @@ def test_bagging_allow_nan_tag(bagging, expected_allow_nan): ) def test_bagging_with_metadata_routing(model): """Make sure that metadata routing works with non-default estimator.""" - with sklearn.config_context(enable_metadata_routing=True): - model.fit(iris.data, iris.target) + model.fit(iris.data, iris.target) + + +@pytest.mark.parametrize( + "sub_estimator, caller, callee", + [ + (ConsumingClassifierWithoutPredictProba, "predict", "predict"), + ( + ConsumingClassifierWithoutPredictLogProba, + "predict_log_proba", + "predict_proba", + ), + (ConsumingClassifierWithOnlyPredict, "predict_log_proba", "predict"), + ], +) +@config_context(enable_metadata_routing=True) +def test_metadata_routing_with_dynamic_method_selection(sub_estimator, caller, callee): + """Test that metadata routing works in `BaggingClassifier` with dynamic selection of + the sub-estimator's methods. Here we test only specific test cases, where + sub-estimator methods are not present and are not tested with `ConsumingClassifier` + (which possesses all the methods) in + sklearn/tests/test_metaestimators_metadata_routing.py: `BaggingClassifier.predict()` + dynamically routes to `predict` if the sub-estimator doesn't have `predict_proba` + and `BaggingClassifier.predict_log_proba()` dynamically routes to `predict_proba` if + the sub-estimator doesn't have `predict_log_proba`, or to `predict`, if it doesn't + have it. + """ + X = np.array([[0, 2], [1, 4], [2, 6]]) + y = [1, 2, 3] + sample_weight, metadata = [1], "a" + registry = _Registry() + estimator = sub_estimator(registry=registry) + set_callee_request = "set_" + callee + "_request" + getattr(estimator, set_callee_request)(sample_weight=True, metadata=True) + + bagging = BaggingClassifier(estimator=estimator) + bagging.fit(X, y) + getattr(bagging, caller)( + X=np.array([[1, 1], [1, 3], [0, 2]]), + sample_weight=sample_weight, + metadata=metadata, + ) + + assert len(registry) + for estimator in registry: + check_recorded_metadata( + obj=estimator, + method=callee, + parent=caller, + sample_weight=sample_weight, + metadata=metadata, + ) + + +# End of Metadata Routing Tests +# ============================= @pytest.mark.parametrize( diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index 98503652df6f0..c4af13ef66344 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -218,9 +218,9 @@ def predict(self, X): def predict_proba(self, X): # dummy probabilities to support predict_proba - y_proba = np.empty(shape=(len(X), 2)) - y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0]) - y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0]) + y_proba = np.empty(shape=(len(X), len(self.classes_)), dtype=np.float32) + # each row sums up to 1.0: + y_proba[:] = np.random.dirichlet(alpha=np.ones(len(self.classes_)), size=len(X)) return y_proba def predict_log_proba(self, X): @@ -298,16 +298,16 @@ def predict_proba(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( self, sample_weight=sample_weight, metadata=metadata ) - y_proba = np.empty(shape=(len(X), 2)) - y_proba[: len(X) // 2, :] = np.asarray([1.0, 0.0]) - y_proba[len(X) // 2 :, :] = np.asarray([0.0, 1.0]) + y_proba = np.empty(shape=(len(X), len(self.classes_)), dtype=np.float32) + # each row sums up to 1.0: + y_proba[:] = np.random.dirichlet(alpha=np.ones(len(self.classes_)), size=len(X)) return y_proba def predict_log_proba(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( self, sample_weight=sample_weight, metadata=metadata ) - return np.zeros(shape=(len(X), 2)) + return self.predict_proba(X) def decision_function(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( @@ -325,6 +325,46 @@ def score(self, X, y, sample_weight="default", metadata="default"): return 1 +class ConsumingClassifierWithoutPredictProba(ConsumingClassifier): + """ConsumingClassifier without a predict_proba method, but with predict_log_proba. + + Used to mimic dynamic method selection such as in the `_parallel_predict_proba()` + function called by `BaggingClassifier`. + """ + + @property + def predict_proba(self): + raise AttributeError("This estimator does not support predict_proba") + + +class ConsumingClassifierWithoutPredictLogProba(ConsumingClassifier): + """ConsumingClassifier without a predict_log_proba method, but with predict_proba. + + Used to mimic dynamic method selection such as in + `BaggingClassifier.predict_log_proba()`. + """ + + @property + def predict_log_proba(self): + raise AttributeError("This estimator does not support predict_log_proba") + + +class ConsumingClassifierWithOnlyPredict(ConsumingClassifier): + """ConsumingClassifier with only a predict method. + + Used to mimic dynamic method selection such as in + `BaggingClassifier.predict_log_proba()`. + """ + + @property + def predict_proba(self): + raise AttributeError("This estimator does not support predict_proba") + + @property + def predict_log_proba(self): + raise AttributeError("This estimator does not support predict_log_proba") + + class ConsumingTransformer(TransformerMixin, BaseEstimator): """A transformer which accepts metadata on fit and transform. diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index 6947c14ff5e59..ae2a186a3c5c2 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -329,7 +329,18 @@ "X": X, "y": y, "preserves_metadata": False, - "estimator_routing_methods": ["fit"], + "estimator_routing_methods": [ + "fit", + "predict", + "predict_proba", + "predict_log_proba", + "decision_function", + ], + "method_mapping": { + "predict": ["predict", "predict_proba"], + "predict_proba": ["predict", "predict_proba"], + "predict_log_proba": ["predict", "predict_proba", "predict_log_proba"], + }, }, { "metaestimator": BaggingRegressor, @@ -338,7 +349,7 @@ "X": X, "y": y, "preserves_metadata": False, - "estimator_routing_methods": ["fit"], + "estimator_routing_methods": ["fit", "predict"], }, { "metaestimator": RidgeCV,