diff --git a/doc/whats_new/upcoming_changes/sklearn.ensemble/31414.fix.rst b/doc/whats_new/upcoming_changes/sklearn.ensemble/31414.fix.rst
new file mode 100644
index 0000000000000..6a881a3174850
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.ensemble/31414.fix.rst
@@ -0,0 +1,7 @@
+- :class:`ensemble.BaggingClassifier`, :class:`ensemble.BaggingRegressor`
+  and :class:`ensemble.IsolationForest` now use `sample_weight` to draw
+  the samples instead of forwarding them multiplied by a uniformly sampled
+  mask to the underlying estimators. Furthermore, `max_samples` is now
+  interpreted as a fraction of `sample_weight.sum()` instead of `X.shape[0]`
+  when passed as a float.
+  By :user:`Antoine Baker `.
diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py
index 34b613b15281a..b727c7f233975 100644
--- a/sklearn/ensemble/_bagging.py
+++ b/sklearn/ensemble/_bagging.py
@@ -72,6 +72,7 @@ def _generate_bagging_indices(
     n_samples,
     max_features,
     max_samples,
+    sample_weight,
 ):
     """Randomly draw feature and sample indices."""
     # Get valid random state
@@ -81,18 +82,37 @@ def _generate_bagging_indices(
     feature_indices = _generate_indices(
         random_state, bootstrap_features, n_features, max_features
     )
-    sample_indices = _generate_indices(
-        random_state, bootstrap_samples, n_samples, max_samples
-    )
+    if sample_weight is None:
+        sample_indices = _generate_indices(
+            random_state, bootstrap_samples, n_samples, max_samples
+        )
+    else:
+        normalized_sample_weight = sample_weight / np.sum(sample_weight)
+        sample_indices = random_state.choice(
+            n_samples,
+            max_samples,
+            replace=bootstrap_samples,
+            p=normalized_sample_weight,
+        )
 
     return feature_indices, sample_indices
 
 
+def _consumes_sample_weight(estimator):
+    if _routing_enabled():
+        request_or_router = get_routing_for_object(estimator)
+        consumes_sample_weight = request_or_router.consumes("fit", ("sample_weight",))
+    else:
+        consumes_sample_weight = has_fit_parameter(estimator, "sample_weight")
+    return consumes_sample_weight
+
+
 def _parallel_build_estimators(
     n_estimators,
     ensemble,
     X,
     y,
+    sample_weight,
     seeds,
     total_n_estimators,
     verbose,
@@ -108,22 +128,12 @@ def _parallel_build_estimators(
     bootstrap_features = ensemble.bootstrap_features
     has_check_input = has_fit_parameter(ensemble.estimator_, "check_input")
     requires_feature_indexing = bootstrap_features or max_features != n_features
+    consumes_sample_weight = _consumes_sample_weight(ensemble.estimator_)
 
     # Build estimators
     estimators = []
     estimators_features = []
 
-    # TODO: (slep6) remove if condition for unrouted sample_weight when metadata
-    # routing can't be disabled.
-    support_sample_weight = has_fit_parameter(ensemble.estimator_, "sample_weight")
-    if not _routing_enabled() and (
-        not support_sample_weight and fit_params.get("sample_weight") is not None
-    ):
-        raise ValueError(
-            "The base estimator doesn't support sample weight, but sample_weight is "
-            "passed to the fit method."
-        )
-
     for i in range(n_estimators):
         if verbose > 1:
             print(
@@ -139,7 +149,8 @@ def _parallel_build_estimators(
         else:
             estimator_fit = estimator.fit
 
-        # Draw random feature, sample indices
+        # Draw random feature, sample indices (using normalized sample_weight
+        # as probabilities if provided).
         features, indices = _generate_bagging_indices(
             random_state,
             bootstrap_features,
@@ -148,45 +159,22 @@ def _parallel_build_estimators(
             n_samples,
             max_features,
             max_samples,
+            sample_weight,
         )
 
         fit_params_ = fit_params.copy()
 
-        # TODO(SLEP6): remove if condition for unrouted sample_weight when metadata
-        # routing can't be disabled.
-        # 1. If routing is enabled, we will check if the routing supports sample
-        # weight and use it if it does.
-        # 2. If routing is not enabled, we will check if the base
-        # estimator supports sample_weight and use it if it does.
-
         # Note: Row sampling can be achieved either through setting sample_weight or
-        # by indexing. The former is more efficient. Therefore, use this method
+        # by indexing. The former is more memory efficient. Therefore, use this method
         # if possible, otherwise use indexing.
-        if _routing_enabled():
-            request_or_router = get_routing_for_object(ensemble.estimator_)
-            consumes_sample_weight = request_or_router.consumes(
-                "fit", ("sample_weight",)
-            )
-        else:
-            consumes_sample_weight = support_sample_weight
         if consumes_sample_weight:
-            # Draw sub samples, using sample weights, and then fit
-            curr_sample_weight = _check_sample_weight(
-                fit_params_.pop("sample_weight", None), X
-            ).copy()
-
-            if bootstrap:
-                sample_counts = np.bincount(indices, minlength=n_samples)
-                curr_sample_weight *= sample_counts
-            else:
-                not_indices_mask = ~indices_to_mask(indices, n_samples)
-                curr_sample_weight[not_indices_mask] = 0
-
-            fit_params_["sample_weight"] = curr_sample_weight
+            # Row sampling by setting sample_weight
+            indices_as_sample_weight = np.bincount(indices, minlength=n_samples)
+            fit_params_["sample_weight"] = indices_as_sample_weight
             X_ = X[:, features] if requires_feature_indexing else X
             estimator_fit(X_, y, **fit_params_)
         else:
-            # cannot use sample_weight, so use indexing
+            # Row sampling by indexing
            y_ = _safe_indexing(y, indices)
             X_ = _safe_indexing(X, indices)
             fit_params_ = _check_method_params(X, params=fit_params_, indices=indices)
@@ -354,9 +342,11 @@ def fit(self, X, y, sample_weight=None, **fit_params):
             regression).
 
         sample_weight : array-like of shape (n_samples,), default=None
-            Sample weights. If None, then samples are equally weighted.
-            Note that this is supported only if the base estimator supports
-            sample weighting.
+            Sample weights. If None, then samples are equally weighted. Used as
+            probabilities to sample the training set. Note that the expected
+            frequency semantics for the `sample_weight` parameter are only
+            fulfilled when sampling with replacement (`bootstrap=True`).
+
         **fit_params : dict
             Parameters to pass to the underlying estimators.
 
@@ -386,6 +376,15 @@ def fit(self, X, y, sample_weight=None, **fit_params):
             multi_output=True,
         )
 
+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
+
+            if not self.bootstrap:
+                warn(
+                    f"When fitting {self.__class__.__name__} with sample_weight "
+                    f"it is recommended to use bootstrap=True, got {self.bootstrap}."
+                )
+
         return self._fit(
             X,
             y,
@@ -435,8 +434,6 @@ def _fit(
 
         sample_weight : array-like of shape (n_samples,), default=None
             Sample weights. If None, then samples are equally weighted.
-            Note that this is supported only if the base estimator supports
-            sample weighting.
         **fit_params : dict, default=None
             Parameters to pass to the :term:`fit` method of the underlying
@@ -457,18 +454,11 @@ def _fit(
         # Check parameters
         self._validate_estimator(self._get_estimator())
 
-        if sample_weight is not None:
-            fit_params["sample_weight"] = sample_weight
-
         if _routing_enabled():
             routed_params = process_routing(self, "fit", **fit_params)
         else:
             routed_params = Bunch()
             routed_params.estimator = Bunch(fit=fit_params)
-            if "sample_weight" in fit_params:
-                routed_params.estimator.fit["sample_weight"] = fit_params[
-                    "sample_weight"
-                ]
 
         if max_depth is not None:
             self.estimator_.max_depth = max_depth
@@ -476,11 +466,26 @@ def _fit(
         # Validate max_samples
         if max_samples is None:
             max_samples = self.max_samples
-        elif not isinstance(max_samples, numbers.Integral):
-            max_samples = int(max_samples * X.shape[0])
 
-        if max_samples > X.shape[0]:
-            raise ValueError("max_samples must be <= n_samples")
+        if not isinstance(max_samples, numbers.Integral):
+            if sample_weight is None:
+                max_samples = max(int(max_samples * X.shape[0]), 1)
+            else:
+                sw_sum = np.sum(sample_weight)
+                if sw_sum <= 1:
+                    raise ValueError(
+                        f"The total sum of sample weights is {sw_sum}, which prevents "
+                        "resampling with a fractional value for max_samples="
+                        f"{max_samples}. Either pass max_samples as an integer or "
+                        "use a larger sample_weight."
+                    )
+                max_samples = max(int(max_samples * sw_sum), 1)
+
+        if not self.bootstrap and max_samples > X.shape[0]:
+            raise ValueError(
+                f"Effective max_samples={max_samples} must be <= n_samples="
+                f"{X.shape[0]} to be able to sample without replacement."
+            )
 
         # Store validated integer row sampling value
         self._max_samples = max_samples
@@ -499,6 +504,11 @@ def _fit(
         # Store validated integer feature sampling value
         self._max_features = max_features
 
+        # Store sample_weight (needed in _get_estimators_indices). Note that
+        # we intentionally do not materialize `sample_weight=None` as an array
+        # of ones to avoid unnecessarily cluttering trained estimator pickles.
+        self._sample_weight = sample_weight
+
         # Other checks
         if not self.bootstrap and self.oob_score:
             raise ValueError("Out of bag estimation only available if bootstrap=True")
@@ -552,6 +562,7 @@ def _fit(
                 self,
                 X,
                 y,
+                sample_weight,
                 seeds[starts[i] : starts[i + 1]],
                 total_n_estimators,
                 verbose=self.verbose,
@@ -596,6 +607,7 @@ def _get_estimators_indices(self):
                 self._n_samples,
                 self._max_features,
                 self._max_samples,
+                self._sample_weight,
             )
 
             yield feature_indices, sample_indices
@@ -726,7 +738,8 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
         replacement by default, see `bootstrap` for more details).
 
         - If int, then draw `max_samples` samples.
-        - If float, then draw `max_samples * X.shape[0]` samples.
+        - If float, then draw `max_samples * X.shape[0]` unweighted samples
+          or `max_samples * sample_weight.sum()` weighted samples.
 
     max_features : int or float, default=1.0
         The number of features to draw from X to train each base estimator (
@@ -737,8 +750,10 @@ class BaggingClassifier(ClassifierMixin, BaseBagging):
         - If float, then draw `max(1, int(max_features * n_features_in_))` features.
 
     bootstrap : bool, default=True
-        Whether samples are drawn with replacement. If False, sampling
-        without replacement is performed.
+        Whether samples are drawn with replacement. If False, sampling without
+        replacement is performed. If fitting with `sample_weight`, it is
+        strongly recommended to choose True, as only drawing with replacement
+        will ensure the expected frequency semantics of `sample_weight`.
     bootstrap_features : bool, default=False
         Whether features are drawn with replacement.
@@ -1245,8 +1260,10 @@ class BaggingRegressor(RegressorMixin, BaseBagging):
         - If float, then draw `max(1, int(max_features * n_features_in_))` features.
 
     bootstrap : bool, default=True
-        Whether samples are drawn with replacement. If False, sampling
-        without replacement is performed.
+        Whether samples are drawn with replacement. If False, sampling without
+        replacement is performed. If fitting with `sample_weight`, it is
+        strongly recommended to choose True, as only drawing with replacement
+        will ensure the expected frequency semantics of `sample_weight`.
 
     bootstrap_features : bool, default=False
         Whether features are drawn with replacement.
diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py
index 4e5287af7f699..31c5491ccb6c9 100644
--- a/sklearn/ensemble/_iforest.py
+++ b/sklearn/ensemble/_iforest.py
@@ -20,7 +20,12 @@
 from ..utils._chunking import get_chunk_n_rows
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
 from ..utils.parallel import Parallel, delayed
-from ..utils.validation import _num_samples, check_is_fitted, validate_data
+from ..utils.validation import (
+    _check_sample_weight,
+    _num_samples,
+    check_is_fitted,
+    validate_data,
+)
 from ._bagging import BaseBagging
 
 __all__ = ["IsolationForest"]
@@ -317,6 +322,10 @@ def fit(self, X, y=None, sample_weight=None):
         X = validate_data(
             self, X, accept_sparse=["csc"], dtype=tree_dtype, ensure_all_finite=False
         )
+
+        if sample_weight is not None:
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
+
         if issparse(X):
             # Pre-sort indices to avoid that each individual tree of the
             # ensemble sorts the indices.
@@ -350,7 +359,7 @@ def fit(self, X, y=None, sample_weight=None):
         super()._fit(
             X,
             y,
-            max_samples,
+            max_samples=max_samples,
             max_depth=max_depth,
             sample_weight=sample_weight,
             check_input=False,
diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py
index 2cb9336bfd759..67fb5c763606f 100644
--- a/sklearn/ensemble/tests/test_bagging.py
+++ b/sklearn/ensemble/tests/test_bagging.py
@@ -5,6 +5,7 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
+import re
 from itertools import cycle, product
 
 import joblib
@@ -42,7 +43,11 @@
 )
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.utils import check_random_state
-from sklearn.utils._testing import assert_array_almost_equal, assert_array_equal
+from sklearn.utils._testing import (
+    assert_allclose,
+    assert_array_almost_equal,
+    assert_array_equal,
+)
 from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
 
 rng = check_random_state(0)
@@ -589,28 +594,6 @@ def test_bagging_with_pipeline():
     assert isinstance(estimator[0].steps[-1][1].random_state, int)
 
 
-class DummyZeroEstimator(BaseEstimator):
-    def fit(self, X, y):
-        self.classes_ = np.unique(y)
-        return self
-
-    def predict(self, X):
-        return self.classes_[np.zeros(X.shape[0], dtype=int)]
-
-
-def test_bagging_sample_weight_unsupported_but_passed():
-    estimator = BaggingClassifier(DummyZeroEstimator())
-    rng = check_random_state(0)
-
-    estimator.fit(iris.data, iris.target).predict(iris.data)
-    with pytest.raises(ValueError):
-        estimator.fit(
-            iris.data,
-            iris.target,
-            sample_weight=rng.randint(10, size=(iris.data.shape[0])),
-        )
-
-
 def test_warm_start(random_state=42):
     # Test if fitting incrementally with warm start gives a forest of the
     # right size and the same results as a normal fit.
@@ -692,6 +675,138 @@ def test_warm_start_with_oob_score_fails():
         clf.fit(X, y)
 
 
+def test_warning_bootstrap_sample_weight():
+    X, y = iris.data, iris.target
+    sample_weight = np.ones_like(y)
+    clf = BaggingClassifier(bootstrap=False)
+    warn_msg = (
+        "When fitting BaggingClassifier with sample_weight "
+        "it is recommended to use bootstrap=True"
+    )
+    with pytest.warns(UserWarning, match=warn_msg):
+        clf.fit(X, y, sample_weight=sample_weight)
+
+    X, y = diabetes.data, diabetes.target
+    sample_weight = np.ones_like(y)
+    reg = BaggingRegressor(bootstrap=False)
+    warn_msg = (
+        "When fitting BaggingRegressor with sample_weight "
+        "it is recommended to use bootstrap=True"
+    )
+    with pytest.warns(UserWarning, match=warn_msg):
+        reg.fit(X, y, sample_weight=sample_weight)
+
+
+def test_invalid_sample_weight_max_samples_bootstrap_combinations():
+    X, y = iris.data, iris.target
+
+    # Case 1: small weights and fractional max_samples would lead to sampling
+    # less than 1 sample, which is not allowed.
+    clf = BaggingClassifier(max_samples=1.0)
+    sample_weight = np.ones_like(y) / (2 * len(y))
+    expected_msg = (
+        r"The total sum of sample weights is 0.5(\d*), which prevents resampling with "
+        r"a fractional value for max_samples=1\.0\. Either pass max_samples as an "
+        r"integer or use a larger sample_weight\."
+    )
+    with pytest.raises(ValueError, match=expected_msg):
+        clf.fit(X, y, sample_weight=sample_weight)
+
+    # Case 2: large weights and bootstrap=False would lead to sampling without
+    # replacement more than the number of samples, which is not allowed.
+    clf = BaggingClassifier(bootstrap=False, max_samples=1.0)
+    sample_weight = np.ones_like(y)
+    sample_weight[-1] = 2
+    expected_msg = re.escape(
+        "max_samples=151 must be <= n_samples=150 to be able to sample without "
+        "replacement."
+    )
+    with pytest.raises(ValueError, match=expected_msg):
+        with pytest.warns(
+            UserWarning, match="When fitting BaggingClassifier with sample_weight"
+        ):
+            clf.fit(X, y, sample_weight=sample_weight)
+
+
+class EstimatorAcceptingSampleWeight(BaseEstimator):
+    """Fake estimator accepting sample_weight"""
+
+    def fit(self, X, y, sample_weight=None):
+        """Record values passed during fit"""
+        self.X_ = X
+        self.y_ = y
+        self.sample_weight_ = sample_weight
+
+    def predict(self, X):
+        pass
+
+
+class EstimatorRejectingSampleWeight(BaseEstimator):
+    """Fake estimator rejecting sample_weight"""
+
+    def fit(self, X, y):
+        """Record values passed during fit"""
+        self.X_ = X
+        self.y_ = y
+
+    def predict(self, X):
+        pass
+
+
+@pytest.mark.parametrize("bagging_class", [BaggingRegressor, BaggingClassifier])
+@pytest.mark.parametrize("accept_sample_weight", [False, True])
+@pytest.mark.parametrize("metadata_routing", [False, True])
+@pytest.mark.parametrize("max_samples", [10, 0.8])
+def test_draw_indices_using_sample_weight(
+    bagging_class, accept_sample_weight, metadata_routing, max_samples
+):
+    X = np.arange(100).reshape(-1, 1)
+    y = np.repeat([0, 1], 50)
+    # all indices except 4 and 5 have zero weight
+    sample_weight = np.zeros(100)
+    sample_weight[4] = 1
+    sample_weight[5] = 2
+    if accept_sample_weight:
+        base_estimator = EstimatorAcceptingSampleWeight()
+    else:
+        base_estimator = EstimatorRejectingSampleWeight()
+
+    n_samples, n_features = X.shape
+
+    if isinstance(max_samples, float):
+        # max_samples passed as a fraction of the input data. Since
+        # sample_weight is provided, the effective number of samples is the
+        # sum of the sample weights.
+        expected_integer_max_samples = int(max_samples * sample_weight.sum())
+    else:
+        expected_integer_max_samples = max_samples
+
+    with config_context(enable_metadata_routing=metadata_routing):
+        # TODO(slep006): remove block when default routing is implemented
+        if metadata_routing and accept_sample_weight:
+            base_estimator = base_estimator.set_fit_request(sample_weight=True)
+        bagging = bagging_class(base_estimator, max_samples=max_samples, n_estimators=4)
+        bagging.fit(X, y, sample_weight=sample_weight)
+
+    for estimator, samples in zip(bagging.estimators_, bagging.estimators_samples_):
+        counts = np.bincount(samples, minlength=n_samples)
+        assert sum(counts) == len(samples) == expected_integer_max_samples
+        # only indices 4 and 5 should appear
+        assert np.isin(samples, [4, 5]).all()
+        if accept_sample_weight:
+            # sampled indices represented through weighting
+            assert estimator.X_.shape == (n_samples, n_features)
+            assert estimator.y_.shape == (n_samples,)
+            assert_allclose(estimator.X_, X)
+            assert_allclose(estimator.y_, y)
+            assert_allclose(estimator.sample_weight_, counts)
+        else:
+            # sampled indices represented through indexing
+            assert estimator.X_.shape == (expected_integer_max_samples, n_features)
+            assert estimator.y_.shape == (expected_integer_max_samples,)
+            assert_allclose(estimator.X_, X[samples])
+            assert_allclose(estimator.y_, y[samples])
+
+
 def test_oob_score_removed_on_warm_start():
     X, y = make_hastie_10_2(n_samples=100, random_state=1)
 
diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py
index f4ed228ec2f9d..2120c8a0c51f6 100644
--- a/sklearn/tests/test_metaestimators_metadata_routing.py
+++ b/sklearn/tests/test_metaestimators_metadata_routing.py
@@ -330,7 +330,7 @@
         "y": y,
         "preserves_metadata": False,
         "estimator_routing_methods": [
-            "fit",
+            ("fit", ["metadata"]),
             "predict",
             "predict_proba",
             "predict_log_proba",
@@ -349,7 +349,7 @@
         "X": X,
         "y": y,
         "preserves_metadata": False,
-        "estimator_routing_methods": ["fit", "predict"],
+        "estimator_routing_methods": [("fit", ["metadata"]), "predict"],
     },
     {
         "metaestimator": RidgeCV,
@@ -459,7 +459,13 @@
 - X: X-data to fit and predict
 - y: y-data to fit
 - estimator_routing_methods: list of all methods to check for routing metadata
-  to the sub-estimator
+  to the sub-estimator. Each value is either a str or a tuple:
+  - str: the name of the method; all metadata in this method must be routed to
+    the sub-estimator
+  - tuple: the first element is the name of the method, the second element is a
+    list of metadata keys to be passed to the sub-estimator. This is useful if
+    certain metadata such as `sample_weight` are never routed and only consumed,
+    such as in `BaggingClassifier` and `BaggingRegressor`.
 - preserves_metadata:
   - True (default): the metaestimator passes the metadata to the sub-estimator
     without modification. We check that the values recorded by
@@ -562,6 +568,32 @@ def get_init_args(metaestimator_info, sub_estimator_consumes):
     )
 
 
+def filter_metadata_in_routing_methods(estimator_routing_methods):
+    """Process estimator_routing_methods and return a dict.
+
+    Parameters
+    ----------
+    estimator_routing_methods : list of str or tuple
+        The estimator_routing_methods info from METAESTIMATORS.
+
+    Returns
+    -------
+    routing_methods : dict
+        The dictionary is of the form {"method": ["metadata", ...]}.
+        It specifies the list of metadata keys for each routing method.
+        By default the list includes `sample_weight` and `metadata`.
+ """ + res = dict() + for method_spec in estimator_routing_methods: + if isinstance(method_spec, str): + method = method_spec + metadata = ["sample_weight", "metadata"] + else: + method, metadata = method_spec + res[method] = metadata + return res + + def set_requests(obj, *, method_mapping, methods, metadata_name, value=True): """Call `set_{method}_request` on a list of methods from the sub-estimator. @@ -662,10 +694,12 @@ def test_error_on_missing_requests_for_sub_estimator(metaestimator): metaestimator_class = metaestimator["metaestimator"] X = metaestimator["X"] y = metaestimator["y"] - routing_methods = metaestimator["estimator_routing_methods"] + routing_methods = filter_metadata_in_routing_methods( + metaestimator["estimator_routing_methods"] + ) - for method_name in routing_methods: - for key in ["sample_weight", "metadata"]: + for method_name, metadata_keys in routing_methods.items(): + for key in metadata_keys: kwargs, (estimator, _), (scorer, _), *_ = get_init_args( metaestimator, sub_estimator_consumes=True ) @@ -721,12 +755,14 @@ def test_setting_request_on_sub_estimator_removes_error(metaestimator): metaestimator_class = metaestimator["metaestimator"] X = metaestimator["X"] y = metaestimator["y"] - routing_methods = metaestimator["estimator_routing_methods"] + routing_methods = filter_metadata_in_routing_methods( + metaestimator["estimator_routing_methods"] + ) method_mapping = metaestimator.get("method_mapping", {}) preserves_metadata = metaestimator.get("preserves_metadata", True) - for method_name in routing_methods: - for key in ["sample_weight", "metadata"]: + for method_name, metadata_keys in routing_methods.items(): + for key in metadata_keys: val = {"sample_weight": sample_weight, "metadata": metadata}[key] method_kwargs = {key: val} @@ -797,8 +833,9 @@ def set_request(estimator, method_name): metaestimator_class = metaestimator["metaestimator"] X = metaestimator["X"] y = metaestimator["y"] - routing_methods = metaestimator["estimator_routing_methods"] - + routing_methods = filter_metadata_in_routing_methods( + metaestimator["estimator_routing_methods"] + ) for method_name in routing_methods: kwargs, (estimator, _), (_, _), (_, _) = get_init_args( metaestimator, sub_estimator_consumes=False