
FIX incorporate resampling when computing OOB score in BRF #656


Merged (3 commits) on Dec 5, 2019
4 changes: 4 additions & 0 deletions doc/whats_new/v0.6.rst
@@ -28,6 +28,10 @@ Bug fixes
`cross_val_predict` is used to take advantage of the parallelism.
:pr:`599` by :user:`Shihab Shahriar Khan <Shihab-Shahriar>`.

- Fix a bug in :class:`imblearn.ensemble.BalancedRandomForestClassifier`
leading to a wrong computation of the OOB score.
:pr:`656` by :user:`Guillaume Lemaitre <glemaitre>`.

Maintenance
...........

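For context, a minimal usage sketch of the affected estimator (synthetic data, illustrative only): with the bug this entry describes, the reported oob_score_ could drift far from held-out accuracy because OOB votes were accumulated at the wrong row positions.

from sklearn.datasets import make_classification
from imblearn.ensemble import BalancedRandomForestClassifier

# hypothetical imbalanced data for illustration
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=0)

clf = BalancedRandomForestClassifier(oob_score=True, random_state=0)
clf.fit(X, y)
print(clf.oob_score_)                    # OOB accuracy estimate
print(clf.oob_decision_function_.shape)  # (n_samples, n_classes)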
61 changes: 61 additions & 0 deletions imblearn/ensemble/_forest.py
@@ -19,6 +19,7 @@
from sklearn.ensemble._base import _set_random_states
from sklearn.ensemble._forest import _get_n_samples_bootstrap
from sklearn.ensemble._forest import _parallel_build_trees
from sklearn.ensemble._forest import _generate_unsampled_indices
from sklearn.exceptions import DataConversionWarning
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils import check_array
@@ -545,5 +546,65 @@ def fit(self, X, y, sample_weight=None):

return self

def _set_oob_score(self, X, y):
"""Compute out-of-bag score."""
X = check_array(X, dtype=DTYPE, accept_sparse='csr')

n_classes_ = self.n_classes_
n_samples = y.shape[0]

oob_decision_function = []
oob_score = 0.0
predictions = [np.zeros((n_samples, n_classes_[k]))
for k in range(self.n_outputs_)]

for sampler, estimator in zip(self.samplers_, self.estimators_):
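            # evaluate the tree on the balanced resample drawn by its own
            # sampler rather than on the full original dataset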
X_resample = X[sampler.sample_indices_]
y_resample = y[sampler.sample_indices_]

n_sample_subset = y_resample.shape[0]
n_samples_bootstrap = _get_n_samples_bootstrap(
n_sample_subset, self.max_samples
)

unsampled_indices = _generate_unsampled_indices(
estimator.random_state, n_sample_subset, n_samples_bootstrap
)
p_estimator = estimator.predict_proba(
X_resample[unsampled_indices, :], check_input=False
)

if self.n_outputs_ == 1:
p_estimator = [p_estimator]

for k in range(self.n_outputs_):
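                # OOB indices refer to positions within the resample; map
                # them back to the original rows before accumulating votes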
indices = sampler.sample_indices_[unsampled_indices]
predictions[k][indices, :] += p_estimator[k]

for k in range(self.n_outputs_):
if (predictions[k].sum(axis=1) == 0).any():
warn("Some inputs do not have OOB scores. "
"This probably means too few trees were used "
"to compute any reliable oob estimates.")

with np.errstate(invalid="ignore", divide="ignore"):
# with the resampling, we are likely to have rows not included
# for the OOB score leading to division by zero
decision = (predictions[k] /
predictions[k].sum(axis=1)[:, np.newaxis])
mask_scores = np.isnan(np.sum(decision, axis=1))
oob_decision_function.append(decision)
oob_score += np.mean(
y[~mask_scores, k] == np.argmax(predictions[k][~mask_scores],
axis=1),
axis=0)

if self.n_outputs_ == 1:
self.oob_decision_function_ = oob_decision_function[0]
else:
self.oob_decision_function_ = oob_decision_function

self.oob_score_ = oob_score / self.n_outputs_

def _more_tags(self):
return {"multioutput": False}
23 changes: 20 additions & 3 deletions imblearn/ensemble/tests/test_forest.py
@@ -4,6 +4,7 @@

from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_array_equal

@@ -108,13 +109,16 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
brf.fit(X, y, sample_weight)


@pytest.mark.filterwarnings("ignore:Some inputs do not have OOB scores")
def test_balanced_random_forest_oob(imbalanced_dataset):
X, y = imbalanced_dataset
X_train, X_test, y_train, y_test = train_test_split(
X, y, random_state=42, stratify=y
)
est = BalancedRandomForestClassifier(oob_score=True, random_state=0)

n_samples = X.shape[0]
est.fit(X[: n_samples // 2, :], y[: n_samples // 2])
test_score = est.score(X[n_samples // 2:, :], y[n_samples // 2:])
est.fit(X_train, y_train)
test_score = est.score(X_test, y_test)

assert abs(test_score - est.oob_score_) < 0.1

@@ -176,3 +180,16 @@ def test_balanced_random_forest_pruning(imbalanced_dataset):
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count

assert n_nodes_no_pruning > n_nodes_pruning


def test_balanced_random_forest_oob_binomial():
    # Regression test for #655: check that the OOB score is close to 0.5
    # for a binomial experiment (labels drawn independently of X).
rng = np.random.RandomState(42)
n_samples = 1000
X = np.arange(n_samples).reshape(-1, 1)
y = rng.binomial(1, 0.5, size=n_samples)

erf = BalancedRandomForestClassifier(oob_score=True, random_state=42)
erf.fit(X, y)
assert np.abs(erf.oob_score_ - 0.5) < 0.05
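
Why 0.5: the labels here are Bernoulli(0.5) noise independent of X, so any classifier's expected accuracy on samples it did not see during training is 0.5, and a correctly computed OOB estimate must sit near chance level. As a sanity sketch (not part of the test suite), the same check can be run against scikit-learn's plain random forest, whose OOB computation was not affected:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from imblearn.ensemble import BalancedRandomForestClassifier

rng = np.random.RandomState(42)
X = np.arange(1000).reshape(-1, 1)
y = rng.binomial(1, 0.5, size=1000)  # pure label noise

for cls in (RandomForestClassifier, BalancedRandomForestClassifier):
    clf = cls(oob_score=True, random_state=42).fit(X, y)
    print(cls.__name__, clf.oob_score_)  # both should hover around 0.5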