diff --git a/doc/ensemble.rst b/doc/ensemble.rst index e556e4693..69cc443f3 100644 --- a/doc/ensemble.rst +++ b/doc/ensemble.rst @@ -77,7 +77,9 @@ each tree of the forest will be provided a balanced bootstrap sample :class:`~sklearn.ensemble.RandomForestClassifier`:: >>> from imblearn.ensemble import BalancedRandomForestClassifier - >>> brf = BalancedRandomForestClassifier(n_estimators=100, random_state=0) + >>> brf = BalancedRandomForestClassifier( + ... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True + ... ) >>> brf.fit(X_train, y_train) BalancedRandomForestClassifier(...) >>> y_pred = brf.predict(X_test) diff --git a/doc/whats_new/v0.11.rst b/doc/whats_new/v0.11.rst index 8e54a275b..aa49204f1 100644 --- a/doc/whats_new/v0.11.rst +++ b/doc/whats_new/v0.11.rst @@ -30,6 +30,11 @@ Deprecation and will be removed in version 0.13. Use `categorical_encoder_` instead. :pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`. +- The default of the parameters `sampling_strategy` and `replacement` will change in + :class:`~imblearn.ensemble.BalancedRandomForestClassifier` to follow the + implementation of the original paper. These changes will take effect in version 0.13. + :pr:`1006` by :user:`Guillaume Lemaitre <glemaitre>`. + Enhancements ............
diff --git a/examples/applications/plot_impact_imbalanced_classes.py b/examples/applications/plot_impact_imbalanced_classes.py index 278033ebb..3c50e9ed0 100644 --- a/examples/applications/plot_impact_imbalanced_classes.py +++ b/examples/applications/plot_impact_imbalanced_classes.py @@ -319,7 +319,9 @@ rf_clf = make_pipeline( preprocessor_tree, - BalancedRandomForestClassifier(random_state=42, n_jobs=2), + BalancedRandomForestClassifier( + sampling_strategy="all", replacement=True, random_state=42, n_jobs=2 + ), ) # %% diff --git a/examples/ensemble/plot_comparison_ensemble_classifier.py b/examples/ensemble/plot_comparison_ensemble_classifier.py index 0eec5403b..2ed2eb29f 100644 --- a/examples/ensemble/plot_comparison_ensemble_classifier.py +++ b/examples/ensemble/plot_comparison_ensemble_classifier.py @@ -143,7 +143,9 @@ from imblearn.ensemble import BalancedRandomForestClassifier rf = RandomForestClassifier(n_estimators=50, random_state=0) -brf = BalancedRandomForestClassifier(n_estimators=50, random_state=0) +brf = BalancedRandomForestClassifier( + n_estimators=50, sampling_strategy="all", replacement=True, random_state=0 +) rf.fit(X_train, y_train) brf.fit(X_train, y_train) diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index 6e96908bf..18a132591 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -36,10 +36,9 @@ from ..base import _ParamsValidationMixin from ..pipeline import make_pipeline from ..under_sampling import RandomUnderSampler -from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution from ..utils._docstring import _n_jobs_docstring, _random_state_docstring -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Hidden, Interval, StrOptions from ..utils._validation import check_sampling_strategy from ..utils.fixes import _fit_context from ._common import _random_forest_classifier_parameter_constraints @@ -100,7 +99,6 @@ def 
_local_parallel_build_trees( @Substitution( - sampling_strategy=BaseUnderSampler._sampling_strategy_docstring, n_jobs=_n_jobs_docstring, random_state=_random_state_docstring, ) @@ -193,11 +191,56 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif Whether to use out-of-bag samples to estimate the generalization accuracy. - {sampling_strategy} + sampling_strategy : float, str, dict, callable, default="auto" + Sampling information to sample the data set. + + - When ``float``, it corresponds to the desired ratio of the number of + samples in the minority class over the number of samples in the + majority class after resampling. Therefore, the ratio is expressed as + :math:`\\alpha_{{us}} = N_{{m}} / N_{{rM}}` where :math:`N_{{m}}` is the + number of samples in the minority class and + :math:`N_{{rM}}` is the number of samples in the majority class + after resampling. + + .. warning:: + ``float`` is only available for **binary** classification. An + error is raised for multi-class classification. + + - When ``str``, specify the class targeted by the resampling. The + number of samples in the different classes will be equalized. + Possible choices are: + + ``'majority'``: resample only the majority class; + + ``'not minority'``: resample all classes but the minority class; + + ``'not majority'``: resample all classes but the majority class; + + ``'all'``: resample all classes; + + ``'auto'``: equivalent to ``'not minority'``. + + - When ``dict``, the keys correspond to the targeted classes. The + values correspond to the desired number of samples for each targeted + class. + + - When callable, a function taking ``y`` and returning a ``dict``. The keys + correspond to the targeted classes. The values correspond to the + desired number of samples for each class. + + .. versionchanged:: 0.11 + The default of `sampling_strategy` will change from `"auto"` to + `"all"` in version 0.13.
This forces the use of a bootstrap of the + minority class as proposed in [1]_. replacement : bool, default=False Whether or not to sample randomly with replacement or not. + .. versionchanged:: 0.11 + The default of `replacement` will change from `False` to `True` in + version 0.13. This forces the use of a bootstrap of the + minority class and sampling with replacement as proposed in [1]_. + {n_jobs} {random_state} @@ -351,7 +394,8 @@ class labels (multi-output problem). >>> X, y = make_classification(n_samples=1000, n_classes=3, ... n_informative=4, weights=[0.2, 0.3, 0.5], ... random_state=0) - >>> clf = BalancedRandomForestClassifier(max_depth=2, random_state=0) + >>> clf = BalancedRandomForestClassifier( + ... sampling_strategy="all", replacement=True, max_depth=2, random_state=0) >>> clf.fit(X, y) BalancedRandomForestClassifier(...) >>> print(clf.feature_importances_) @@ -376,8 +420,9 @@ class labels (multi-output problem). StrOptions({"auto", "majority", "not minority", "not majority", "all"}), dict, callable, + Hidden(StrOptions({"warn"})), ], - "replacement": ["boolean"], + "replacement": ["boolean", Hidden(StrOptions({"warn"}))], } ) @@ -395,8 +440,8 @@ def __init__( min_impurity_decrease=0.0, bootstrap=True, oob_score=False, - sampling_strategy="auto", - replacement=False, + sampling_strategy="warn", + replacement="warn", n_jobs=None, random_state=None, verbose=0, @@ -450,7 +495,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()): self.base_sampler_ = RandomUnderSampler( sampling_strategy=self._sampling_strategy, - replacement=self.replacement, + replacement=self._replacement, ) def _make_sampler_estimator(self, random_state=None): @@ -496,6 +541,31 @@ def fit(self, X, y, sample_weight=None): The fitted instance. """ self._validate_params() + # TODO: remove in 0.13 + if self.sampling_strategy == "warn": + warn( + "The default of `sampling_strategy` will change from `'auto'` to " + "`'all'` in version 0.13.
This change will follow the implementation " + "proposed in the original paper. Set to `'all'` to silence this " + "warning and adopt the future behaviour.", + FutureWarning, + ) + self._sampling_strategy = "auto" + else: + self._sampling_strategy = self.sampling_strategy + + if self.replacement == "warn": + warn( + "The default of `replacement` will change from `False` to " + "`True` in version 0.13. This change will follow the implementation " + "proposed in the original paper. Set to `True` to silence this " + "warning and adopt the future behaviour.", + FutureWarning, + ) + self._replacement = False + else: + self._replacement = self.replacement + # Validate or convert input data if issparse(y): raise ValueError("sparse multilabel-indicator for y is not supported.") @@ -533,7 +603,7 @@ def fit(self, X, y, sample_weight=None): if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous: y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE) - if isinstance(self.sampling_strategy, dict): + if isinstance(self._sampling_strategy, dict): self._sampling_strategy = { np.where(self.classes_[0] == key)[0][0]: value for key, value in check_sampling_strategy( @@ -543,7 +613,7 @@ def fit(self, X, y, sample_weight=None): ).items() } else: - self._sampling_strategy = self.sampling_strategy + self._sampling_strategy = self._sampling_strategy if expanded_class_weight is not None: if sample_weight is not None: diff --git a/imblearn/ensemble/tests/test_forest.py b/imblearn/ensemble/tests/test_forest.py index c7ae65f85..697722b20 100644 --- a/imblearn/ensemble/tests/test_forest.py +++ b/imblearn/ensemble/tests/test_forest.py @@ -28,7 +28,9 @@ def imbalanced_dataset(): def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset): - brf = BalancedRandomForestClassifier(n_estimators=5) + brf = BalancedRandomForestClassifier( + n_estimators=5, sampling_strategy="all", replacement=True + ) brf.fit(*imbalanced_dataset) with pytest.raises(ValueError, match="must 
be larger or equal to"): @@ -44,7 +46,12 @@ def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset): def test_balanced_random_forest(imbalanced_dataset): n_estimators = 10 - brf = BalancedRandomForestClassifier(n_estimators=n_estimators, random_state=0) + brf = BalancedRandomForestClassifier( + n_estimators=n_estimators, + random_state=0, + sampling_strategy="all", + replacement=True, + ) brf.fit(*imbalanced_dataset) assert len(brf.samplers_) == n_estimators @@ -56,7 +63,12 @@ def test_balanced_random_forest(imbalanced_dataset): def test_balanced_random_forest_attributes(imbalanced_dataset): X, y = imbalanced_dataset n_estimators = 10 - brf = BalancedRandomForestClassifier(n_estimators=n_estimators, random_state=0) + brf = BalancedRandomForestClassifier( + n_estimators=n_estimators, + random_state=0, + sampling_strategy="all", + replacement=True, + ) brf.fit(X, y) for idx in range(n_estimators): @@ -80,7 +92,9 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset): rng = np.random.RandomState(42) X, y = imbalanced_dataset sample_weight = rng.rand(y.shape[0]) - brf = BalancedRandomForestClassifier(n_estimators=5, random_state=0) + brf = BalancedRandomForestClassifier( + n_estimators=5, random_state=0, sampling_strategy="all", replacement=True + ) brf.fit(X, y, sample_weight) @@ -95,6 +109,8 @@ def test_balanced_random_forest_oob(imbalanced_dataset): random_state=0, n_estimators=1000, min_samples_leaf=2, + sampling_strategy="all", + replacement=True, ) est.fit(X_train, y_train) @@ -104,14 +120,19 @@ def test_balanced_random_forest_oob(imbalanced_dataset): # Check warning if not enough estimators est = BalancedRandomForestClassifier( - oob_score=True, random_state=0, n_estimators=1, bootstrap=True + oob_score=True, + random_state=0, + n_estimators=1, + bootstrap=True, + sampling_strategy="all", + replacement=True, ) with pytest.warns(UserWarning) and np.errstate(divide="ignore", invalid="ignore"): est.fit(X, y) def 
test_balanced_random_forest_grid_search(imbalanced_dataset): - brf = BalancedRandomForestClassifier() + brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True) grid = GridSearchCV(brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3) grid.fit(*imbalanced_dataset) @@ -127,6 +148,8 @@ def test_little_tree_with_small_max_samples(): n_estimators=1, random_state=rng, max_samples=None, + sampling_strategy="all", + replacement=True, ) # Second fit with max samples restricted to just 2 @@ -134,6 +157,8 @@ def test_little_tree_with_small_max_samples(): n_estimators=1, random_state=rng, max_samples=2, + sampling_strategy="all", + replacement=True, ) est1.fit(X, y) @@ -147,11 +172,13 @@ def test_little_tree_with_small_max_samples(): def test_balanced_random_forest_pruning(imbalanced_dataset): - brf = BalancedRandomForestClassifier() + brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True) brf.fit(*imbalanced_dataset) n_nodes_no_pruning = brf.estimators_[0].tree_.node_count - brf_pruned = BalancedRandomForestClassifier(ccp_alpha=0.015) + brf_pruned = BalancedRandomForestClassifier( + ccp_alpha=0.015, sampling_strategy="all", replacement=True + ) brf_pruned.fit(*imbalanced_dataset) n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count @@ -168,7 +195,12 @@ def test_balanced_random_forest_oob_binomial(ratio): X = np.arange(n_samples).reshape(-1, 1) y = rng.binomial(1, ratio, size=n_samples) - erf = BalancedRandomForestClassifier(oob_score=True, random_state=42) + erf = BalancedRandomForestClassifier( + oob_score=True, + random_state=42, + sampling_strategy="not minority", + replacement=False, + ) erf.fit(X, y) assert np.abs(erf.oob_score_ - 0.5) < 0.1 @@ -176,7 +208,9 @@ def test_balanced_random_forest_oob_binomial(ratio): def test_balanced_bagging_classifier_n_features(): """Check that we raise a FutureWarning when accessing `n_features_`.""" X, y = load_iris(return_X_y=True) - estimator = 
BalancedRandomForestClassifier().fit(X, y) + estimator = BalancedRandomForestClassifier( + sampling_strategy="all", replacement=True + ).fit(X, y) with pytest.warns(FutureWarning, match="`n_features_` was deprecated"): estimator.n_features_ @@ -184,9 +218,24 @@ def test_balanced_bagging_classifier_n_features(): @pytest.mark.skipif( sklearn_version < parse_version("1.2"), reason="requires scikit-learn>=1.2" ) -def test_balanced_bagging_classifier_base_estimator(): +def test_balanced_random_forest_classifier_base_estimator(): """Check that we raise a FutureWarning when accessing `base_estimator_`.""" X, y = load_iris(return_X_y=True) - estimator = BalancedRandomForestClassifier().fit(X, y) + estimator = BalancedRandomForestClassifier( + sampling_strategy="all", replacement=True + ).fit(X, y) with pytest.warns(FutureWarning, match="`base_estimator_` was deprecated"): estimator.base_estimator_ + + +# TODO: remove in 0.13 +def test_balanced_random_forest_change_behaviour(imbalanced_dataset): + """Check that we raise a change of behaviour for the parameters `sampling_strategy` + and `replacement`. 
+ """ + estimator = BalancedRandomForestClassifier(sampling_strategy="all") + with pytest.warns(FutureWarning, match="The default of `replacement`"): + estimator.fit(*imbalanced_dataset) + estimator = BalancedRandomForestClassifier(replacement=True) + with pytest.warns(FutureWarning, match="The default of `sampling_strategy`"): + estimator.fit(*imbalanced_dataset) diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 0a7915a44..e8f6f7fb4 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -75,6 +75,10 @@ def _set_checking_parameters(estimator): ) if name == "KMeansSMOTE": estimator.set_params(kmeans_estimator=12) + if name == "BalancedRandomForestClassifier": + # TODO: remove in 0.13 + # future default in 0.13 + estimator.set_params(replacement=True, sampling_strategy="all") def _yield_sampler_checks(sampler):