diff --git a/doc/whats_new/upcoming_changes/sklearn.ensemble/31165.fix.rst b/doc/whats_new/upcoming_changes/sklearn.ensemble/31165.fix.rst
new file mode 100644
index 0000000000000..22a11f0d276d8
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.ensemble/31165.fix.rst
@@ -0,0 +1,4 @@
+- :class:`ensemble.BaggingClassifier`, :class:`ensemble.BaggingRegressor`
+  and :class:`ensemble.IsolationForest` now use `sample_weight` to draw
+  the samples instead of forwarding them to the underlying estimators.
+  By :user:`Antoine Baker <antoinebaker>`.
diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py
index 34b613b15281a..9efd1540f15fe 100644
--- a/sklearn/ensemble/_bagging.py
+++ b/sklearn/ensemble/_bagging.py
@@ -22,7 +22,7 @@
     column_or_1d,
 )
 from ..utils._mask import indices_to_mask
-from ..utils._param_validation import HasMethods, Interval, RealNotInt
+from ..utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
 from ..utils._tags import get_tags
 from ..utils.metadata_routing import (
     MetadataRouter,
@@ -72,6 +72,7 @@ def _generate_bagging_indices(
     n_samples,
     max_features,
     max_samples,
+    sample_weight,
 ):
     """Randomly draw feature and sample indices."""
     # Get valid random state
@@ -81,21 +82,47 @@ def _generate_bagging_indices(
     feature_indices = _generate_indices(
         random_state, bootstrap_features, n_features, max_features
     )
-    sample_indices = _generate_indices(
-        random_state, bootstrap_samples, n_samples, max_samples
-    )
-
+    if sample_weight is None:
+        sample_indices = _generate_indices(
+            random_state, bootstrap_samples, n_samples, max_samples
+        )
+    else:
+        normalized_sample_weight = sample_weight / np.sum(sample_weight)
+        sample_indices = random_state.choice(
+            n_samples,
+            max_samples,
+            replace=bootstrap_samples,
+            p=normalized_sample_weight,
+        )
     return feature_indices, sample_indices
 
 
+def _consumes_sample_weight(estimator):
+    # TODO(SLEP6): remove the `if` condition for unrouted sample_weight once
+    # metadata routing can no longer be disabled.
+    # 1. If routing is enabled, check whether the routing requests
+    #    sample_weight and use it if it does.
+    # 2. If routing is not enabled, check whether the base estimator
+    #    supports sample_weight and use it if it does.
+    support_sample_weight = has_fit_parameter(estimator, "sample_weight")
+    if _routing_enabled():
+        request_or_router = get_routing_for_object(estimator)
+        consumes_sample_weight = request_or_router.consumes("fit", ("sample_weight",))
+    else:
+        consumes_sample_weight = support_sample_weight
+    return consumes_sample_weight
+
+
 def _parallel_build_estimators(
     n_estimators,
     ensemble,
     X,
     y,
+    sample_weight,
     seeds,
     total_n_estimators,
     verbose,
+    sampling_strategy,
     check_input,
     fit_params,
 ):
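For intuition, the weighted branch added to `_generate_bagging_indices` above is equivalent to the following standalone NumPy sketch (the weight values are hypothetical, not part of the patch)::

    import numpy as np

    rng = np.random.RandomState(0)
    sample_weight = np.array([1.0, 1.0, 8.0, 0.0])  # hypothetical weights
    p = sample_weight / np.sum(sample_weight)

    # replace=True mirrors bootstrap_samples=True: heavily weighted rows are
    # drawn more often, and zero-weight rows are never drawn.
    indices = rng.choice(4, 4, replace=True, p=p)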
@@ -113,17 +140,6 @@ def _parallel_build_estimators(
     estimators = []
     estimators_features = []
 
-    # TODO: (slep6) remove if condition for unrouted sample_weight when metadata
-    # routing can't be disabled.
-    support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
-    if not _routing_enabled() and (
-        not support_sample_weight and fit_params.get("sample_weight") is not None
-    ):
-        raise ValueError(
-            "The base estimator doesn't support sample weight, but sample_weight is "
-            "passed to the fit method."
-        )
-
     for i in range(n_estimators):
         if verbose > 1:
             print(
@@ -140,6 +156,9 @@ def _parallel_build_estimators(
             estimator_fit = estimator.fit
 
         # Draw random feature, sample indices
+        sample_weight_in_indices = (
+            sample_weight if sampling_strategy == "indexing" else None
+        )
         features, indices = _generate_bagging_indices(
             random_state,
             bootstrap_features,
@@ -148,32 +167,15 @@ def _parallel_build_estimators(
             n_samples,
             max_features,
             max_samples,
+            sample_weight_in_indices,
         )
 
         fit_params_ = fit_params.copy()
-
-        # TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
-        # routing can't be disabled.
-        # 1. If routing is enabled, we will check if the routing supports sample
-        #    weight and use it if it does.
-        # 2. If routing is not enabled, we will check if the base
-        #    estimator supports sample_weight and use it if it does.
-        # Note: Row sampling can be achieved either through setting sample_weight or
-        # by indexing. The former is more efficient. Therefore, use this method
-        # if possible, otherwise use indexing.
-        if _routing_enabled():
-            request_or_router = get_routing_for_object(ensemble.estimator_)
-            consumes_sample_weight = request_or_router.consumes(
-                "fit", ("sample_weight",)
-            )
-        else:
-            consumes_sample_weight = support_sample_weight
-
-        if consumes_sample_weight:
+        # Note: Row sampling can be achieved either through setting sample_weight
+        # or by indexing, controlled by the `sampling_strategy` argument.
+        if sampling_strategy == "weighting":
             # Draw sub samples, using sample weights, and then fit
-            curr_sample_weight = _check_sample_weight(
-                fit_params_.pop("sample_weight", None), X
-            ).copy()
+            curr_sample_weight = _check_sample_weight(sample_weight, X).copy()
 
             if bootstrap:
                 sample_counts = np.bincount(indices, minlength=n_samples)
@@ -185,7 +187,7 @@ def _parallel_build_estimators(
             fit_params_["sample_weight"] = curr_sample_weight
             X_ = X[:, features] if requires_feature_indexing else X
             estimator_fit(X_, y, **fit_params_)
-        else:
+        elif sampling_strategy == "indexing":
             # cannot use sample_weight, so use indexing
             y_ = _safe_indexing(y, indices)
             X_ = _safe_indexing(X, indices)
@@ -193,6 +195,8 @@ def _parallel_build_estimators(
             if requires_feature_indexing:
                 X_ = X_[:, features]
             estimator_fit(X_, y_, **fit_params_)
+        else:
+            raise ValueError(f"{sampling_strategy=} must be 'indexing' or 'weighting'.")
 
         estimators.append(estimator)
         estimators_features.append(features)
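The 'weighting' branch above folds the drawn indices into a per-row weight vector instead of materializing `X[indices]`. A minimal sketch of that bookkeeping, with toy values that are not part of the patch::

    import numpy as np

    n_samples = 5
    indices = np.array([0, 2, 2, 4, 4])      # hypothetical bootstrap draw
    curr_sample_weight = np.ones(n_samples)  # uniform base weights

    # Rows drawn k times get weight k, rows never drawn get weight 0, which
    # is equivalent to fitting on X[indices] for estimators that honor
    # sample_weight.
    sample_counts = np.bincount(indices, minlength=n_samples)
    curr_sample_weight *= sample_counts      # -> [1. 0. 2. 0. 2.]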
@@ -304,6 +308,7 @@ class BaseBagging(BaseEnsemble, metaclass=ABCMeta):
         "n_jobs": [None, Integral],
         "random_state": ["random_state"],
         "verbose": ["verbose"],
+        "sampling_strategy": [StrOptions({"auto", "indexing", "weighting"})],
     }
 
     @abstractmethod
@@ -321,6 +326,7 @@ def __init__(
         n_jobs=None,
         random_state=None,
         verbose=0,
+        sampling_strategy="auto",
     ):
         super().__init__(
             estimator=estimator,
@@ -335,6 +341,7 @@ def __init__(
         self.n_jobs = n_jobs
         self.random_state = random_state
         self.verbose = verbose
+        self.sampling_strategy = sampling_strategy
 
     @_fit_context(
         # BaseBagging.estimator is not validated yet
@@ -355,8 +362,16 @@ def fit(self, X, y, sample_weight=None, **fit_params):
         sample_weight : array-like of shape (n_samples,), default=None
             Sample weights. If None, then samples are equally weighted.
-            Note that this is supported only if the base estimator supports
-            sample weighting.
+            Used as probabilities to draw the samples used to fit the
+            sub-estimators. It is strongly recommended to use bootstrap=True
+            (draw with replacement) for statistical soundness.
+
+            .. versionchanged:: 1.7
+
+                The sample weights are now used to draw the samples and are
+                no longer forwarded to the underlying estimators. The base
+                estimator is therefore no longer required to support sample
+                weights.
 
         **fit_params : dict
             Parameters to pass to the underlying estimators.
@@ -386,11 +401,25 @@ def fit(self, X, y, sample_weight=None, **fit_params):
             multi_output=True,
         )
 
+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
+        if sample_weight is not None and not self.bootstrap:
+            warn(
+                f"When fitting {self.__class__.__name__} with sample_weight "
+                f"it is recommended to use bootstrap=True, got {self.bootstrap}."
+            )
+        if sample_weight is not None and self.sampling_strategy == "weighting":
+            warn(
+                f"When fitting {self.__class__.__name__} with sample_weight "
+                "it is recommended to use sampling_strategy='indexing' or "
+                f"'auto', got {self.sampling_strategy}."
+            )
+
         return self._fit(
             X,
             y,
-            max_samples=self.max_samples,
             sample_weight=sample_weight,
+            max_samples=self.max_samples,
             **fit_params,
         )
 
@@ -401,10 +430,10 @@ def _fit(
         self,
         X,
         y,
+        sample_weight=None,
         max_samples=None,
         max_depth=None,
         check_input=True,
-        sample_weight=None,
         **fit_params,
     ):
         """Build a Bagging ensemble of estimators from the training
@@ -420,6 +449,10 @@ def _fit(
             The target values (class labels in classification, real numbers
             in regression).
 
+        sample_weight : array-like of shape (n_samples,), default=None
+            Sample weights. If None, then samples are equally weighted.
+            Used as probabilities to draw the samples.
+
         max_samples : int or float, default=None
             Argument to use instead of self.max_samples.
 
@@ -433,11 +466,6 @@ def _fit(
             If the meta-estimator already checks the input, set this value
             to False to prevent redundant input validation.
 
-        sample_weight : array-like of shape (n_samples,), default=None
-            Sample weights. If None, then samples are equally weighted.
-            Note that this is supported only if the base estimator supports
-            sample weighting.
-
         **fit_params : dict, default=None
             Parameters to pass to the :term:`fit` method of the underlying
             estimator.
@@ -457,18 +485,11 @@ def _fit(
         # Check parameters
         self._validate_estimator(self._get_estimator())
 
-        if sample_weight is not None:
-            fit_params["sample_weight"] = sample_weight
-
         if _routing_enabled():
             routed_params = process_routing(self, "fit", **fit_params)
         else:
             routed_params = Bunch()
             routed_params.estimator = Bunch(fit=fit_params)
-            if "sample_weight" in fit_params:
-                routed_params.estimator.fit["sample_weight"] = fit_params[
-                    "sample_weight"
-                ]
 
         if max_depth is not None:
             self.estimator_.max_depth = max_depth
@@ -499,6 +520,28 @@ def _fit(
         # Store validated integer feature sampling value
         self._max_features = max_features
 
+        # Validate sampling_strategy
+        consumes_sample_weight = _consumes_sample_weight(self.estimator_)
+        sampling_strategy = self.sampling_strategy
+        if sampling_strategy == "auto":
+            if consumes_sample_weight and sample_weight is None:
+                sampling_strategy = "weighting"
+            else:
+                sampling_strategy = "indexing"
+        if (sampling_strategy == "weighting") and not consumes_sample_weight:
+            raise ValueError(
+                "The base estimator doesn't support sample weight, which is "
+                "required when sampling_strategy='weighting'."
+            )
+
+        # Store validated sampling_strategy
+        self._sampling_strategy = sampling_strategy
+
+        # Store sample_weight_in_indices for _get_estimators_indices
+        self._sample_weight_in_indices = (
+            sample_weight if sampling_strategy == "indexing" else None
+        )
+
         # Other checks
         if not self.bootstrap and self.oob_score:
             raise ValueError("Out of bag estimation only available if bootstrap=True")
@@ -526,7 +569,7 @@ def _fit(
         elif n_more_estimators == 0:
             warn(
                 "Warm-start fitting without increasing n_estimators does not "
-                "fit new trees."
+                "fit new estimators."
             )
             return self
 
@@ -552,9 +595,11 @@ def _fit(
                 self,
                 X,
                 y,
+                sample_weight,
                 seeds[starts[i] : starts[i + 1]],
                 total_n_estimators,
                 verbose=self.verbose,
+                sampling_strategy=sampling_strategy,
                 check_input=check_input,
                 fit_params=routed_params.estimator.fit,
             )
@@ -596,6 +641,7 @@ def _get_estimators_indices(self):
                 self._n_samples,
                 self._max_features,
                 self._max_samples,
+                self._sample_weight_in_indices,
             )
 
             yield feature_indices, sample_indices
@@ -738,7 +784,9 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
     bootstrap : bool, default=True
         Whether samples are drawn with replacement. If False, sampling
-        without replacement is performed.
+        without replacement is performed. When fitting with `sample_weight`,
+        it is strongly recommended to use bootstrap=True for statistical
+        soundness.
 
     bootstrap_features : bool, default=False
         Whether features are drawn with replacement.
@@ -762,7 +810,7 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
         processors. See :term:`Glossary <n_jobs>` for more details.
 
     random_state : int, RandomState instance or None, default=None
-        Controls the random resampling of the original dataset
+        Controls the random sampling of the original dataset
        (sample wise and feature wise).
         If the base estimator accepts a `random_state` attribute, a different
         seed is generated for each instance in the ensemble.
@@ -772,6 +820,26 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
     verbose : int, default=0
         Controls the verbosity when fitting and predicting.
 
+    sampling_strategy : {'auto', 'indexing', 'weighting'}, default='auto'
+        How to handle the samples drawn from the original dataset.
+
+        - 'indexing' explicitly indexes the original dataset. On the downside,
+          it creates copies and therefore has a memory overhead. On the
+          upside, it does not require the base estimator to support
+          `sample_weight`.
+        - 'weighting' does not index the original dataset and is therefore
+          more memory efficient. Instead, it encodes the drawn indices as
+          per-sample counts and passes them as `sample_weight` to the base
+          estimator, which must therefore support `sample_weight`.
+        - 'auto' selects 'indexing' if the bagging estimator is fit with
+          `sample_weight` or if the base estimator does not support
+          `sample_weight`, and 'weighting' otherwise.
+
+        .. warning::
+            Only the 'indexing' option combined with bootstrap=True gives
+            statistically correct results when fitting the bagging estimator
+            with `sample_weight`.
+
+        .. versionadded:: 1.8
+
     Attributes
     ----------
     estimator_ : estimator
@@ -866,6 +934,7 @@ def __init__(
         n_jobs=None,
         random_state=None,
         verbose=0,
+        sampling_strategy="auto",
     ):
         super().__init__(
             estimator=estimator,
@@ -879,6 +948,7 @@ def __init__(
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
+            sampling_strategy=sampling_strategy,
         )
 
     def _get_estimator(self):
@@ -1246,7 +1316,9 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
     bootstrap : bool, default=True
         Whether samples are drawn with replacement. If False, sampling
-        without replacement is performed.
+        without replacement is performed. When fitting with `sample_weight`,
+        it is strongly recommended to use bootstrap=True for statistical
+        soundness.
 
     bootstrap_features : bool, default=False
         Whether features are drawn with replacement.
@@ -1267,7 +1339,7 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
         processors. See :term:`Glossary <n_jobs>` for more details.
 
     random_state : int, RandomState instance or None, default=None
-        Controls the random resampling of the original dataset
+        Controls the random sampling of the original dataset
         (sample wise and feature wise).
         If the base estimator accepts a `random_state` attribute, a different
         seed is generated for each instance in the ensemble.
@@ -1277,6 +1349,26 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
     verbose : int, default=0
         Controls the verbosity when fitting and predicting.
 
+    sampling_strategy : {'auto', 'indexing', 'weighting'}, default='auto'
+        How to handle the samples drawn from the original dataset.
+
+        - 'indexing' explicitly indexes the original dataset. On the downside,
+          it creates copies and therefore has a memory overhead. On the
+          upside, it does not require the base estimator to support
+          `sample_weight`.
+        - 'weighting' does not index the original dataset and is therefore
+          more memory efficient. Instead, it encodes the drawn indices as
+          per-sample counts and passes them as `sample_weight` to the base
+          estimator, which must therefore support `sample_weight`.
+        - 'auto' selects 'indexing' if the bagging estimator is fit with
+          `sample_weight` or if the base estimator does not support
+          `sample_weight`, and 'weighting' otherwise.
+
+        .. warning::
+            Only the 'indexing' option combined with bootstrap=True gives
+            statistically correct results when fitting the bagging estimator
+            with `sample_weight`.
+
+        .. versionadded:: 1.8
+
     Attributes
     ----------
     estimator_ : estimator
@@ -1348,7 +1440,7 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
     >>> regr = BaggingRegressor(estimator=SVR(),
     ...                         n_estimators=10, random_state=0).fit(X, y)
     >>> regr.predict([[0, 0, 0, 0]])
-    array([-2.8720])
+    array([-2.8948])
     """
 
     def __init__(
@@ -1365,6 +1457,7 @@ def __init__(
         n_jobs=None,
         random_state=None,
         verbose=0,
+        sampling_strategy="auto",
     ):
         super().__init__(
             estimator=estimator,
@@ -1378,6 +1471,7 @@ def __init__(
             n_jobs=n_jobs,
             random_state=random_state,
             verbose=verbose,
+            sampling_strategy=sampling_strategy,
         )
 
     def predict(self, X, **params):
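Assuming the `_bagging.py` changes above are applied, the new behavior can be exercised as below. `KNeighborsRegressor` does not accept `sample_weight` in `fit`, which the previous code would have rejected; the data and weights here are illustrative only::

    import numpy as np
    from sklearn.ensemble import BaggingRegressor
    from sklearn.neighbors import KNeighborsRegressor

    rng = np.random.RandomState(0)
    X, y = rng.randn(100, 4), rng.randn(100)
    sample_weight = rng.rand(100)

    # The weights bias which rows each bootstrap sample draws, so the base
    # estimator never needs to consume sample_weight itself.
    reg = BaggingRegressor(
        KNeighborsRegressor(),
        n_estimators=10,
        bootstrap=True,            # recommended when fitting with sample_weight
        sampling_strategy="auto",  # resolves to 'indexing' in this case
        random_state=0,
    )
    reg.fit(X, y, sample_weight=sample_weight)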
diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py
index 4e5287af7f699..5fed54c3b6430 100644
--- a/sklearn/ensemble/_iforest.py
+++ b/sklearn/ensemble/_iforest.py
@@ -20,7 +20,12 @@
 from ..utils._chunking import get_chunk_n_rows
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
 from ..utils.parallel import Parallel, delayed
-from ..utils.validation import _num_samples, check_is_fitted, validate_data
+from ..utils.validation import (
+    _check_sample_weight,
+    _num_samples,
+    check_is_fitted,
+    validate_data,
+)
 from ._bagging import BaseBagging
 
 __all__ = ["IsolationForest"]
@@ -317,6 +322,10 @@ def fit(self, X, y=None, sample_weight=None):
         X = validate_data(
             self, X, accept_sparse=["csc"], dtype=tree_dtype, ensure_all_finite=False
         )
+
+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
+
         if issparse(X):
             # Pre-sort indices to avoid that each individual tree of the
             # ensemble sorts the indices.
@@ -350,9 +359,9 @@ def fit(self, X, y=None, sample_weight=None):
         super()._fit(
             X,
             y,
-            max_samples,
-            max_depth=max_depth,
             sample_weight=sample_weight,
+            max_samples=max_samples,
+            max_depth=max_depth,
             check_input=False,
         )
diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py
index 2cb9336bfd759..6fc11e0970a04 100644
--- a/sklearn/ensemble/tests/test_bagging.py
+++ b/sklearn/ensemble/tests/test_bagging.py
@@ -589,28 +589,6 @@ def test_bagging_with_pipeline():
     assert isinstance(estimator[0].steps[-1][1].random_state, int)
 
 
-class DummyZeroEstimator(BaseEstimator):
-    def fit(self, X, y):
-        self.classes_ = np.unique(y)
-        return self
-
-    def predict(self, X):
-        return self.classes_[np.zeros(X.shape[0], dtype=int)]
-
-
-def test_bagging_sample_weight_unsupported_but_passed():
-    estimator = BaggingClassifier(DummyZeroEstimator())
-    rng = check_random_state(0)
-
-    estimator.fit(iris.data, iris.target).predict(iris.data)
-    with pytest.raises(ValueError):
-        estimator.fit(
-            iris.data,
-            iris.target,
-            sample_weight=rng.randint(10, size=(iris.data.shape[0])),
-        )
-
-
 def test_warm_start(random_state=42):
     # Test if fitting incrementally with warm start gives a forest of the
     # right size and the same results as a normal fit.
@@ -909,8 +887,9 @@ def test_bagging_classifier_with_missing_inputs():
 
 def test_bagging_small_max_features():
     # Check that Bagging estimator can accept low fractional max_features
-    X = np.array([[1, 2], [3, 4]])
-    y = np.array([1, 0])
+    rng = np.random.RandomState(42)
+    X = rng.randn(10, 2)
+    y = rng.randint(2, size=X.shape[0])
 
     bagging = BaggingClassifier(LogisticRegression(), max_features=0.3, random_state=1)
     bagging.fit(X, y)
diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py
index f4ed228ec2f9d..2120c8a0c51f6 100644
--- a/sklearn/tests/test_metaestimators_metadata_routing.py
+++ b/sklearn/tests/test_metaestimators_metadata_routing.py
@@ -330,7 +330,7 @@
         "y": y,
         "preserves_metadata": False,
         "estimator_routing_methods": [
-            "fit",
+            ("fit", ["metadata"]),
             "predict",
             "predict_proba",
             "predict_log_proba",
@@ -349,7 +349,7 @@
         "X": X,
         "y": y,
         "preserves_metadata": False,
-        "estimator_routing_methods": ["fit", "predict"],
+        "estimator_routing_methods": [("fit", ["metadata"]), "predict"],
     },
     {
         "metaestimator": RidgeCV,
@@ -459,7 +459,13 @@
 - X: X-data to fit and predict
 - y: y-data to fit
 - estimator_routing_methods: list of all methods to check for routing metadata
-  to the sub-estimator
+  to the sub-estimator. Each value is either a str or a tuple:
+  - str: the name of the method; all metadata passed to this method must be
+    routed to the sub-estimator
+  - tuple: the first element is the name of the method, the second element is
+    a list of the metadata keys to be passed to the sub-estimator. This is
+    useful if certain metadata such as `sample_weight` are never routed and
+    only consumed, as in `BaggingClassifier` and `BaggingRegressor`.
 - preserves_metadata:
   - True (default): the metaestimator passes the metadata to the
     sub-estimator without modification. We check that the values recorded by
@@ -562,6 +568,32 @@ def get_init_args(metaestimator_info, sub_estimator_consumes):
     )
 
 
+def filter_metadata_in_routing_methods(estimator_routing_methods):
+    """Process estimator_routing_methods and return a dict.
+
+    Parameters
+    ----------
+    estimator_routing_methods : list of str or tuple
+        The estimator_routing_methods info from METAESTIMATORS.
+
+    Returns
+    -------
+    routing_methods : dict
+        The dictionary is of the form {"method": ["metadata", ...]}.
+        It specifies the list of metadata keys for each routing method.
+        By default the list includes `sample_weight` and `metadata`.
+    """
+    res = dict()
+    for method_spec in estimator_routing_methods:
+        if isinstance(method_spec, str):
+            method = method_spec
+            metadata = ["sample_weight", "metadata"]
+        else:
+            method, metadata = method_spec
+        res[method] = metadata
+    return res
+
+
 def set_requests(obj, *, method_mapping, methods, metadata_name, value=True):
     """Call `set_{method}_request` on a list of methods from the sub-estimator.
 
@@ -662,10 +694,12 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator):
     metaestimator_class = metaestimator["metaestimator"]
     X = metaestimator["X"]
     y = metaestimator["y"]
-    routing_methods = metaestimator["estimator_routing_methods"]
+    routing_methods = filter_metadata_in_routing_methods(
+        metaestimator["estimator_routing_methods"]
+    )
 
-    for method_name in routing_methods:
-        for key in ["sample_weight", "metadata"]:
+    for method_name, metadata_keys in routing_methods.items():
+        for key in metadata_keys:
             kwargs, (estimator, _), (scorer, _), *_ = get_init_args(
                 metaestimator, sub_estimator_consumes=True
             )
@@ -721,12 +755,14 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator):
     metaestimator_class = metaestimator["metaestimator"]
     X = metaestimator["X"]
     y = metaestimator["y"]
-    routing_methods = metaestimator["estimator_routing_methods"]
+    routing_methods = filter_metadata_in_routing_methods(
+        metaestimator["estimator_routing_methods"]
+    )
     method_mapping = metaestimator.get("method_mapping", {})
     preserves_metadata = metaestimator.get("preserves_metadata", True)
 
-    for method_name in routing_methods:
-        for key in ["sample_weight", "metadata"]:
+    for method_name, metadata_keys in routing_methods.items():
+        for key in metadata_keys:
             val = {"sample_weight": sample_weight, "metadata": metadata}[key]
             method_kwargs = {key: val}
 
@@ -797,8 +833,9 @@ def set_request(estimator, method_name):
     metaestimator_class = metaestimator["metaestimator"]
     X = metaestimator["X"]
     y = metaestimator["y"]
-    routing_methods = metaestimator["estimator_routing_methods"]
-
+    routing_methods = filter_metadata_in_routing_methods(
+        metaestimator["estimator_routing_methods"]
+    )
     for method_name in routing_methods:
         kwargs, (estimator, _), (_, _), (_, _) = get_init_args(
             metaestimator, sub_estimator_consumes=False
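The test updates above encode that `sample_weight` is now consumed by the bagging meta-estimator itself (it only biases the row draw) rather than being routed, while other fit metadata still requires an explicit request. A sketch of that contract, assuming the patch is applied; the data and weights are illustrative::

    import numpy as np
    from sklearn import config_context
    from sklearn.datasets import make_classification
    from sklearn.ensemble import BaggingClassifier
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=100, random_state=0)
    w = np.random.RandomState(0).rand(100)

    with config_context(enable_metadata_routing=True):
        # No set_fit_request(sample_weight=True) is needed on the
        # sub-estimator: the weights are consumed by BaggingClassifier to
        # draw the indices and are not forwarded to LogisticRegression.fit.
        clf = BaggingClassifier(LogisticRegression(), random_state=0)
        clf.fit(X, y, sample_weight=w)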