Skip to content

[MRG+1] EHN Add bootstrap sample size limit to forest ensembles #14682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Sep 20, 2019

Conversation

notmatthancock
Copy link
Contributor

Adds a max_samples kwarg to forest ensembles that limits the size of the bootstrap samples used to train each estimator. This PR is intended to supersede two previous stalled PRs (#5963 and #9645), which have not been touched for a couple of years. This PR adds unit tests for the various functionalities.

Why is this useful?

When training a random forest classifier on a large dataset with correlated/redundant examples, training on the entire dataset is memory intensive and unnecessary. Further, this results models that occupy an unwieldy amount of memory. As an example consider training on image patches for a segmentation-as-classification problem. It would be useful in this situation to train only on a subset of the available image patches because it's expected that an image patch obtained at one location is highly correlated with the patch obtained one pixel over. Limiting the size of each bootstrap sample to train each estimator is useful in this and similar applications.

Pickled model disk space comparison

Here's a simple test script to show the difference between occupied disk space of the pickled model, using the full dataset size for each bootstrap vs. using just a bootstrap size of 1 (obviously this is dramatic):

File size `max_samples=None`: 0.154GB
File size `max_samples=1`: 5.286e-05GB
import os
import pickle

import numpy as np
from sklearn.ensemble import RandomForestClassifier


rs = np.random.RandomState(1234)
X = rs.randn(100000, 1000)
y = rs.randn(X.shape[0]) > 0

rfc = RandomForestClassifier(
    n_estimators=100, random_state=rs)
rfc.fit(X, y)
with open('rfc.pkl', 'wb') as f:
    pickle.dump(rfc, f)
size = os.stat('rfc.pkl').st_size / 1e9
print("File size `max_samples=None`: {}GB".format(size))

rfc = RandomForestClassifier(
    n_estimators=100, random_state=rs, max_samples=1)
rfc.fit(X, y)
with open('rfc.pkl', 'wb') as f:
    pickle.dump(rfc, f)
size = os.stat('rfc.pkl').st_size / 1e9
print("File size `max_samples=1`: {}GB".format(size))

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should also add the parameter to RandomForestRegressor, isn't it?

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional comments regarding the tests.

@@ -1330,3 +1330,93 @@ def test_forest_degenerate_feature_importances():
gbr = RandomForestRegressor(n_estimators=10).fit(X, y)
assert_array_equal(gbr.feature_importances_,
np.zeros(10, dtype=np.float64))


def test__get_n_bootstrap_samples():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually don't test a private function. Instead we would test through the different estimator (RandomForest, ExtraTrees)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can parametrize easily this case and you can check. You can make a test function where want to check that errors are raised properly (you can parametrize as well).

The other behavior should be done by fitting the estimator and check that we fitted on a subset of data.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check that we fitted on a subset of data.

where do you get that info?

As far as I can see, the sample indices for each bootstrap are not stored anywhere, but translated into sample weights and passed to the estimator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(exception check unit tests refactored in: 081b7b7)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where do you get that info?

Yep I thought that we were having a private function to get back the indices but it does not look like so simple to do without making some hacks. I would need a bit more time to think about how to test this.

unsampled_indices = _generate_unsampled_indices(
estimator.random_state, n_samples)
estimator.random_state, n_samples, n_bootstrap_samples)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnothman here we will test on the left out samples. Since we used max_samples it means that we will use a lot of samples for the OOB. Is it something that we want?

If this is the case, we should add a test to make sure that we have the right complement in case max_samples is not None.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a problem with having a large OOB sample for test? Testing with trees isn't fast, but is a lot faster than training...?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing with trees isn't fast, but is a lot faster than training...?

True. Should not be an issue then

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the sample is much smaller than the dataset I suppose it may be an issue. Maybe we should make a rule that the oob sample is constrained to be no larger than the training set... But that may confuse users trying different values for this parameter

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the sample is much smaller than the dataset

I can see this situation easily arising, e.g, your dataset is 106 examples and you want to to fit with say, max_samples=1000 and n_estimators=1000.

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of additional comments

@Shihab-Shahriar
Copy link

Hello, does this actually limit sample size given to each base tree? I noticed few days back that _parallel_build_trees doesn't actually sample training instances. To inject randomness, it uses bootstrap as a way of re-weighting the sample_weight of instances. In other words, whole dataset is passed to each base tree, but with different sample_weight.

I tried putting a print(len(y)) just before the tree.fit call when bootstrap=True in your code, and it confirmed my suspicion. Please let me know if there's any error in my reasoning.

@glemaitre
Copy link
Member

Hello, does this actually limit sample size given to each base tree? I noticed few days back that _parallel_build_trees doesn't actually sample training instances. To inject randomness, it uses bootstrap as a way of re-weighting the sample_weight of instances. In other words, whole dataset is passed to each base tree, but with different sample_weight.

It is the definition of a bootstrap sample: sampling with replacement so n_samples=X.shape[0].
Now, we allow sampling with replacement allowing n_samples < X.shape[0].

@glemaitre glemaitre self-requested a review September 9, 2019 11:34
@glemaitre
Copy link
Member

@glemaitre, Sorry about that particular choice of word. I realize this may have sounded a lot different than I originally intended.

It is just to be certain that we don't document something wrong :)

@glemaitre glemaitre self-requested a review September 13, 2019 08:48
Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from the small pytest.raises LGTM.

@glemaitre
Copy link
Member

FYI: codecov seems to be wrong. All the diff code is covered.

@glemaitre glemaitre changed the title Add bootstrap sample size limit to forest ensembles [MRG+1] EHN Add bootstrap sample size limit to forest ensembles Sep 13, 2019
@glemaitre
Copy link
Member

Thanks @notmatthancock, let's wait for a second review before merging.

Maybe @jnothman @adrinjalali could have a look

Copy link
Member

@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm good with the implementation. I'm wondering if it needs mentioning in the documentation for bootstrap... This is no longer quite a bootstrap, but I think former proposals reused that parameter name.

If we are happy with a new parameter, this lgtm

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM, thanks @notmatthancock

if max_samples is None:
return n_samples

if isinstance(max_samples, numbers.Integral):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be consistent w.r.t how we treat these fractions. For instance, in optics, we have:

def _validate_size(size, n_samples, param_name):
if size <= 0 or (size !=
int(size)
and size > 1):
raise ValueError('%s must be a positive integer '
'or a float between 0 and 1. Got %r' %
(param_name, size))
elif size > n_samples:
raise ValueError('%s must be no greater than the'
' number of samples (%d). Got %d' %
(param_name, n_samples, size))

And then 1 always means 100% of the data, at least in optics. Do we have a similar semantics in other places?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With PCA, n_components=1 means 1 components while n_components<1 will be a percentage.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excluding 1 from float avoid issue with float comparison as well.

to train each base estimator.
- If None (default), then draw `X.shape[0]` samples.
- If int, then draw `max_samples` samples.
- If float, then draw `max_samples * X.shape[0]` samples.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with the behavior of the float, but we need to document that here, i.e. explicitly say that if float, it must be in (0, 1)

@glemaitre glemaitre merged commit 746efb5 into scikit-learn:master Sep 20, 2019
@glemaitre
Copy link
Member

@notmatthancock I made the small change requested by @adrinjalali and merged.
Thanks a lot for your contribution.

@notmatthancock
Copy link
Contributor Author

Likewise @glemaitre, thanks to you and @jnothman for the helpful review comments.

sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 9, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 9, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports
sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)
sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)

Do not add ccp_alpha to SurvivalTree, because
it relies node_impurity, which is not set for SurvivalTree.
sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)

Do not add ccp_alpha to SurvivalTree, because
it relies node_impurity, which is not set for SurvivalTree.
sebp added a commit to sebp/scikit-survival that referenced this pull request Apr 10, 2020
- Deprecate presort (scikit-learn/scikit-learn#14907)
- Add Minimal Cost-Complexity Pruning to Decision Trees (scikit-learn/scikit-learn#12887)
- Add bootstrap sample size limit to forest ensembles (scikit-learn/scikit-learn#14682)
- Fix deprecated imports (scikit-learn/scikit-learn#9250)

Do not add ccp_alpha to SurvivalTree, because
it relies node_impurity, which is not set for SurvivalTree.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants