Skip to content

API change default of replacement and sampling_strategy in BRF #1006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/ensemble.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ each tree of the forest will be provided a balanced bootstrap sample
:class:`~sklearn.ensemble.RandomForestClassifier`::

>>> from imblearn.ensemble import BalancedRandomForestClassifier
>>> brf = BalancedRandomForestClassifier(n_estimators=100, random_state=0)
>>> brf = BalancedRandomForestClassifier(
... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True
... )
>>> brf.fit(X_train, y_train)
BalancedRandomForestClassifier(...)
>>> y_pred = brf.predict(X_test)
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new/v0.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ Deprecation
and will be removed in version 0.13. Use `categorical_encoder_` instead.
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.

- The default of the parameters `sampling_strategy` and `replacement` will change in
:class:`~imblearn.ensemble.BalancedRandomForestClassifier` to follow the
implementation of the original paper. These changes will take effect in version 0.13.
:pr:`1006` by :user:`Guillaume Lemaitre <glemaitre>`.

Enhancements
............

Expand Down
4 changes: 3 additions & 1 deletion examples/applications/plot_impact_imbalanced_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@

rf_clf = make_pipeline(
preprocessor_tree,
BalancedRandomForestClassifier(random_state=42, n_jobs=2),
BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True, random_state=42, n_jobs=2
),
)

# %%
Expand Down
4 changes: 3 additions & 1 deletion examples/ensemble/plot_comparison_ensemble_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@
from imblearn.ensemble import BalancedRandomForestClassifier

rf = RandomForestClassifier(n_estimators=50, random_state=0)
brf = BalancedRandomForestClassifier(n_estimators=50, random_state=0)
brf = BalancedRandomForestClassifier(
n_estimators=50, sampling_strategy="all", replacement=True, random_state=0
)

rf.fit(X_train, y_train)
brf.fit(X_train, y_train)
Expand Down
92 changes: 81 additions & 11 deletions imblearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@
from ..base import _ParamsValidationMixin
from ..pipeline import make_pipeline
from ..under_sampling import RandomUnderSampler
from ..under_sampling.base import BaseUnderSampler
from ..utils import Substitution
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
from ..utils._param_validation import Interval, StrOptions
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._validation import check_sampling_strategy
from ..utils.fixes import _fit_context
from ._common import _random_forest_classifier_parameter_constraints
Expand Down Expand Up @@ -100,7 +99,6 @@ def _local_parallel_build_trees(


@Substitution(
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
n_jobs=_n_jobs_docstring,
random_state=_random_state_docstring,
)
Expand Down Expand Up @@ -193,11 +191,56 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
Whether to use out-of-bag samples to estimate
the generalization accuracy.

{sampling_strategy}
sampling_strategy : float, str, dict, callable, default="auto"
Sampling information to sample the data set.

- When ``float``, it corresponds to the desired ratio of the number of
samples in the minority class over the number of samples in the
majority class after resampling. Therefore, the ratio is expressed as
:math:`\\alpha_{{us}} = N_{{m}} / N_{{rM}}` where :math:`N_{{m}}` is the
number of samples in the minority class and
:math:`N_{{rM}}` is the number of samples in the majority class
after resampling.

.. warning::
``float`` is only available for **binary** classification. An
error is raised for multi-class classification.

- When ``str``, specify the class targeted by the resampling. The
number of samples in the different classes will be equalized.
Possible choices are:

``'majority'``: resample only the majority class;

``'not minority'``: resample all classes but the minority class;

``'not majority'``: resample all classes but the majority class;

``'all'``: resample all classes;

``'auto'``: equivalent to ``'not minority'``.

- When ``dict``, the keys correspond to the targeted classes. The
values correspond to the desired number of samples for each targeted
class.

- When callable, a function taking ``y`` and returning a ``dict``. The keys
correspond to the targeted classes. The values correspond to the
desired number of samples for each class.

.. versionchanged:: 0.11
The default of `sampling_strategy` will change from `"auto"` to
`"all"` in version 0.13. This forces the use of a bootstrap of the
minority class, as proposed in [1]_.

replacement : bool, default=False
Whether to sample randomly with replacement.

.. versionchanged:: 0.11
The default of `replacement` will change from `False` to `True` in
version 0.13. This forces the use of a bootstrap of the
minority class, drawn with replacement, as proposed in [1]_.

{n_jobs}

{random_state}
Expand Down Expand Up @@ -351,7 +394,8 @@ class labels (multi-output problem).
>>> X, y = make_classification(n_samples=1000, n_classes=3,
... n_informative=4, weights=[0.2, 0.3, 0.5],
... random_state=0)
>>> clf = BalancedRandomForestClassifier(max_depth=2, random_state=0)
>>> clf = BalancedRandomForestClassifier(
... sampling_strategy="all", replacement=True, max_depth=2, random_state=0)
>>> clf.fit(X, y)
BalancedRandomForestClassifier(...)
>>> print(clf.feature_importances_)
Expand All @@ -376,8 +420,9 @@ class labels (multi-output problem).
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
dict,
callable,
Hidden(StrOptions({"warn"})),
],
"replacement": ["boolean"],
"replacement": ["boolean", Hidden(StrOptions({"warn"}))],
}
)

Expand All @@ -395,8 +440,8 @@ def __init__(
min_impurity_decrease=0.0,
bootstrap=True,
oob_score=False,
sampling_strategy="auto",
replacement=False,
sampling_strategy="warn",
replacement="warn",
n_jobs=None,
random_state=None,
verbose=0,
Expand Down Expand Up @@ -450,7 +495,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):

self.base_sampler_ = RandomUnderSampler(
sampling_strategy=self._sampling_strategy,
replacement=self.replacement,
replacement=self._replacement,
)

def _make_sampler_estimator(self, random_state=None):
Expand Down Expand Up @@ -496,6 +541,31 @@ def fit(self, X, y, sample_weight=None):
The fitted instance.
"""
self._validate_params()
# TODO: remove in 0.13
if self.sampling_strategy == "warn":
warn(
"The default of `sampling_strategy` will change from `'auto'` to "
"`'all'` in version 0.13. This change will follow the implementation "
"proposed in the original paper. Set to `'all'` to silence this "
"warning and adopt the future behaviour.",
FutureWarning,
)
self._sampling_strategy = "auto"
else:
self._sampling_strategy = self.sampling_strategy

if self.replacement == "warn":
warn(
"The default of `replacement` will change from `False` to "
"`True` in version 0.13. This change will follow the implementation "
"proposed in the original paper. Set to `True` to silence this "
"warning and adopt the future behaviour.",
FutureWarning,
)
self._replacement = False
else:
self._replacement = self.replacement

# Validate or convert input data
if issparse(y):
raise ValueError("sparse multilabel-indicator for y is not supported.")
Expand Down Expand Up @@ -533,7 +603,7 @@ def fit(self, X, y, sample_weight=None):
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE)

if isinstance(self.sampling_strategy, dict):
if isinstance(self._sampling_strategy, dict):
self._sampling_strategy = {
np.where(self.classes_[0] == key)[0][0]: value
for key, value in check_sampling_strategy(
Expand All @@ -543,7 +613,7 @@ def fit(self, X, y, sample_weight=None):
).items()
}
else:
self._sampling_strategy = self.sampling_strategy
self._sampling_strategy = self._sampling_strategy

if expanded_class_weight is not None:
if sample_weight is not None:
Expand Down
73 changes: 61 additions & 12 deletions imblearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def imbalanced_dataset():


def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):
brf = BalancedRandomForestClassifier(n_estimators=5)
brf = BalancedRandomForestClassifier(
n_estimators=5, sampling_strategy="all", replacement=True
)
brf.fit(*imbalanced_dataset)

with pytest.raises(ValueError, match="must be larger or equal to"):
Expand All @@ -44,7 +46,12 @@ def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):

def test_balanced_random_forest(imbalanced_dataset):
n_estimators = 10
brf = BalancedRandomForestClassifier(n_estimators=n_estimators, random_state=0)
brf = BalancedRandomForestClassifier(
n_estimators=n_estimators,
random_state=0,
sampling_strategy="all",
replacement=True,
)
brf.fit(*imbalanced_dataset)

assert len(brf.samplers_) == n_estimators
Expand All @@ -56,7 +63,12 @@ def test_balanced_random_forest(imbalanced_dataset):
def test_balanced_random_forest_attributes(imbalanced_dataset):
X, y = imbalanced_dataset
n_estimators = 10
brf = BalancedRandomForestClassifier(n_estimators=n_estimators, random_state=0)
brf = BalancedRandomForestClassifier(
n_estimators=n_estimators,
random_state=0,
sampling_strategy="all",
replacement=True,
)
brf.fit(X, y)

for idx in range(n_estimators):
Expand All @@ -80,7 +92,9 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
rng = np.random.RandomState(42)
X, y = imbalanced_dataset
sample_weight = rng.rand(y.shape[0])
brf = BalancedRandomForestClassifier(n_estimators=5, random_state=0)
brf = BalancedRandomForestClassifier(
n_estimators=5, random_state=0, sampling_strategy="all", replacement=True
)
brf.fit(X, y, sample_weight)


Expand All @@ -95,6 +109,8 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
random_state=0,
n_estimators=1000,
min_samples_leaf=2,
sampling_strategy="all",
replacement=True,
)

est.fit(X_train, y_train)
Expand All @@ -104,14 +120,19 @@ def test_balanced_random_forest_oob(imbalanced_dataset):

# Check warning if not enough estimators
est = BalancedRandomForestClassifier(
oob_score=True, random_state=0, n_estimators=1, bootstrap=True
oob_score=True,
random_state=0,
n_estimators=1,
bootstrap=True,
sampling_strategy="all",
replacement=True,
)
with pytest.warns(UserWarning) and np.errstate(divide="ignore", invalid="ignore"):
est.fit(X, y)


def test_balanced_random_forest_grid_search(imbalanced_dataset):
brf = BalancedRandomForestClassifier()
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
grid = GridSearchCV(brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3)
grid.fit(*imbalanced_dataset)

Expand All @@ -127,13 +148,17 @@ def test_little_tree_with_small_max_samples():
n_estimators=1,
random_state=rng,
max_samples=None,
sampling_strategy="all",
replacement=True,
)

# Second fit with max samples restricted to just 2
est2 = BalancedRandomForestClassifier(
n_estimators=1,
random_state=rng,
max_samples=2,
sampling_strategy="all",
replacement=True,
)

est1.fit(X, y)
Expand All @@ -147,11 +172,13 @@ def test_little_tree_with_small_max_samples():


def test_balanced_random_forest_pruning(imbalanced_dataset):
brf = BalancedRandomForestClassifier()
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
brf.fit(*imbalanced_dataset)
n_nodes_no_pruning = brf.estimators_[0].tree_.node_count

brf_pruned = BalancedRandomForestClassifier(ccp_alpha=0.015)
brf_pruned = BalancedRandomForestClassifier(
ccp_alpha=0.015, sampling_strategy="all", replacement=True
)
brf_pruned.fit(*imbalanced_dataset)
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count

Expand All @@ -168,25 +195,47 @@ def test_balanced_random_forest_oob_binomial(ratio):
X = np.arange(n_samples).reshape(-1, 1)
y = rng.binomial(1, ratio, size=n_samples)

erf = BalancedRandomForestClassifier(oob_score=True, random_state=42)
erf = BalancedRandomForestClassifier(
oob_score=True,
random_state=42,
sampling_strategy="not minority",
replacement=False,
)
erf.fit(X, y)
assert np.abs(erf.oob_score_ - 0.5) < 0.1


def test_balanced_bagging_classifier_n_features():
"""Check that we raise a FutureWarning when accessing `n_features_`."""
X, y = load_iris(return_X_y=True)
estimator = BalancedRandomForestClassifier().fit(X, y)
estimator = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True
).fit(X, y)
with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
estimator.n_features_


@pytest.mark.skipif(
sklearn_version < parse_version("1.2"), reason="requires scikit-learn>=1.2"
)
def test_balanced_bagging_classifier_base_estimator():
def test_balanced_random_forest_classifier_base_estimator():
"""Check that we raise a FutureWarning when accessing `base_estimator_`."""
X, y = load_iris(return_X_y=True)
estimator = BalancedRandomForestClassifier().fit(X, y)
estimator = BalancedRandomForestClassifier(
sampling_strategy="all", replacement=True
).fit(X, y)
with pytest.warns(FutureWarning, match="`base_estimator_` was deprecated"):
estimator.base_estimator_


# TODO: remove in 0.13
def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
    """Check the `FutureWarning` announcing the new defaults.

    A parameter among `sampling_strategy` and `replacement` that is left at
    its default value must trigger a `FutureWarning` when calling `fit`.
    """
    X, y = imbalanced_dataset

    # Only `sampling_strategy` is given: the warning must be about `replacement`.
    only_strategy_set = BalancedRandomForestClassifier(sampling_strategy="all")
    with pytest.warns(FutureWarning, match="The default of `replacement`"):
        only_strategy_set.fit(X, y)

    # Only `replacement` is given: the warning must be about `sampling_strategy`.
    only_replacement_set = BalancedRandomForestClassifier(replacement=True)
    with pytest.warns(FutureWarning, match="The default of `sampling_strategy`"):
        only_replacement_set.fit(X, y)
4 changes: 4 additions & 0 deletions imblearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def _set_checking_parameters(estimator):
)
if name == "KMeansSMOTE":
estimator.set_params(kmeans_estimator=12)
if name == "BalancedRandomForestClassifier":
# TODO: remove in 0.13
# future default in 0.13
estimator.set_params(replacement=True, sampling_strategy="all")


def _yield_sampler_checks(sampler):
Expand Down