Skip to content

FIX max_samples was computed on X instead of X_resampled #661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions doc/whats_new/v0.6.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
.. _changes_0_6_1:

Version 0.6.1
==============

**In Development**

This is a bug-fix release to primarily resolve some packaging issues in version
0.6.0. It also includes minor documentation improvements and some bug fixes.

Changelog
---------

Bug fixes
.........

- Fix a bug in :class:`imblearn.ensemble.BalancedRandomForestClassifier`
leading to a wrong number of samples used during fitting due `max_samples`
and therefore a bad computation of the OOB score.
:pr:`656` by :user:`Guillaume Lemaitre <glemaitre>`.

.. _changes_0_6:

Version 0.6.0
Expand Down
5 changes: 5 additions & 0 deletions imblearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def _local_parallel_build_trees(
X_resampled, y_resampled = sampler.fit_resample(X, y)
if sample_weight is not None:
sample_weight = _safe_indexing(sample_weight, sampler.sample_indices_)
if _get_n_samples_bootstrap is not None:
n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0])
tree = _parallel_build_trees(
tree,
forest,
Expand Down Expand Up @@ -214,6 +216,9 @@ class BalancedRandomForestClassifier(RandomForestClassifier):
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
`max_samples` should be in the interval `(0, 1)`.
Be aware that the final number samples used will be the minimum between
the number of samples given in `max_samples` and the number of samples
obtained after resampling.

.. versionadded:: 0.22
Added in `scikit-learn` in 0.22
Expand Down
12 changes: 8 additions & 4 deletions imblearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
X_train, X_test, y_train, y_test = train_test_split(
X, y, random_state=42, stratify=y
)
est = BalancedRandomForestClassifier(oob_score=True, random_state=0)
est = BalancedRandomForestClassifier(
oob_score=True, random_state=0, n_estimators=1000
)

est.fit(X_train, y_train)
test_score = est.score(X_test, y_test)
Expand Down Expand Up @@ -182,14 +184,16 @@ def test_balanced_random_forest_pruning(imbalanced_dataset):
assert n_nodes_no_pruning > n_nodes_pruning


def test_balanced_random_forest_oob_binomial():
@pytest.mark.parametrize("ratio", [0.5, 0.1])
@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores")
def test_balanced_random_forest_oob_binomial(ratio):
# Regression test for #655: check that the oob score is closed to 0.5
# a binomial experiment.
rng = np.random.RandomState(42)
n_samples = 1000
X = np.arange(n_samples).reshape(-1, 1)
y = rng.binomial(1, 0.5, size=n_samples)
y = rng.binomial(1, ratio, size=n_samples)

erf = BalancedRandomForestClassifier(oob_score=True, random_state=42)
erf.fit(X, y)
assert np.abs(erf.oob_score_ - 0.5) < 0.05
assert np.abs(erf.oob_score_ - 0.5) < 0.1