
[MRG] EHN: Change default n_estimators to 100 for random forest #11542


Merged
merged 14 commits into scikit-learn:master on Jul 17, 2018

Conversation

annaayzenshtat
Contributor

Reference Issues/PRs

Fixes #11128.

What does this implement/fix? Explain your changes.

Issues a deprecation warning message for the default n_estimators parameter of the forest classifiers. A test is added for the warning message when the default parameter is used.

Any other comments?

@amueller
Member

Is this based on #11172? The contributor seems to have addressed the comments there yesterday...

@amueller
Member

though it looks like #11172 is still not right...

@@ -758,6 +763,10 @@ class RandomForestClassifier(ForestClassifier):
n_estimators : integer, optional (default=10)
The number of trees in the forest.

.. deprecated:: 0.20
Member

should be "versionchanged" not "deprecated"

Contributor Author

versionchanged as one long word, no spaces?

Member

yes. git grep versionchanged?

Contributor Author

working on that.
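For reference, a minimal sketch of what the requested directive might look like in a numpydoc-style docstring. The function name here is hypothetical, not the actual scikit-learn code; only the `.. versionchanged::` directive and the wording of the change are what the review is asking for:

```python
# Hypothetical function illustrating the ``versionchanged`` directive
# (one word, no spaces) requested instead of ``deprecated``.
def make_forest(n_estimators='warn'):
    """Build a toy forest.

    Parameters
    ----------
    n_estimators : integer, optional (default=10)
        The number of trees in the forest.

        .. versionchanged:: 0.20
           The default value of ``n_estimators`` will change from 10 in
           version 0.20 to 100 in version 0.22.
    """
    return 10 if n_estimators == 'warn' else n_estimators
```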

@@ -1228,3 +1229,23 @@ def test_min_impurity_decrease():
# Simply check if the parameter is passed on correctly. Tree tests
# will suffice for the actual working of this param
assert_equal(tree.min_impurity_decrease, 0.1)


def test_nestimators_future_warning():
Member

@amueller Jul 15, 2018

It might be better to use pytest.parametrize as above instead of the loop, which will run each estimator as a separate test.

Member

FYI:

@pytest.mark.parametrize('forest', [RandomForestClassifier(), RandomForestRegressor(),
                                    ExtraTreesClassifier(), ExtraTreesRegressor(),
                                    RandomTreesEmbedding()])
def test_n_estimators_future_warning(forest):
    ....
    forest.fit(X, y)
    ....

Member

Might be better to parametrize with classes,

@pytest.mark.parametrize('forest', [RandomForestClassifier, RandomForestRegressor,
                         [...]

then create the corresponding instances inside the test -- this works better for getting a human-readable test name with pytest.

Member

Fair enough
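A minimal sketch of the difference being discussed, using hypothetical stand-in classes rather than the real forest estimators (pytest assumed to be available). Parametrizing with the classes themselves gives readable test ids such as `Alpha` and `Beta`, whereas pre-built instances render as opaque ids like `forest0`, `forest1`:

```python
import pytest

# Stand-in estimator classes (hypothetical, not the real forests).
class Alpha:
    def fit(self, X, y):
        return self

class Beta:
    def fit(self, X, y):
        return self

# Parametrizing with the classes themselves yields readable test ids
# ("Alpha", "Beta"); instances would show up as "forest0", "forest1".
@pytest.mark.parametrize('forest_cls', [Alpha, Beta])
def test_fit_returns_self(forest_cls):
    forest = forest_cls()          # instantiate inside the test
    assert forest.fit([[0]], [0]) is forest
```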

@amueller
Member

though it looks like #11172 is still not right...

This looks pretty good. Ideally you'd also catch deprecation warnings if they are raised in the tests now.

Member

@glemaitre glemaitre left a comment


You will also need to add an entry in the what's new file for v0.20 stating the change of behavior in the future.

@@ -242,6 +242,12 @@ def fit(self, X, y, sample_weight=None):
-------
self : object
"""

if self.n_estimators == 'warn':
Member

The check and validation should be done in fit instead of __init__

Contributor Author

So should I change back to n_estimators=10 instead of n_estimators='warn', and then change my if conditional check in the fit() method?

Member

no the warn is good, just the test should be in the other place.

Member

You can refer to: https://github.com/scikit-learn/scikit-learn/pull/11469/files#diff-e6faf37b13574bc591afbf0536128735R864

This is still not merged, but we follow this convention: __init__ just assigns the parameters to the class attributes, and we do checking and validation in the fit method.

Contributor Author

Aren't lines 245 and 246 above inside the fit() method?

Member

Oops, sorry, it is good there. I got confused with another PR :)
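The convention under discussion can be sketched with a simplified toy class (this is an illustration, not the actual scikit-learn code): `__init__` only stores the parameter, and the `'warn'` sentinel is resolved inside `fit`, where the FutureWarning is issued:

```python
import warnings

class ToyForest:
    # Simplified illustration of the scikit-learn convention: __init__
    # only stores parameters; checking and validation happen in fit().
    def __init__(self, n_estimators='warn'):
        self.n_estimators = n_estimators   # no validation here

    def fit(self, X, y):
        if self.n_estimators == 'warn':
            warnings.warn("The default value of n_estimators will change "
                          "from 10 in version 0.20 to 100 in 0.22.",
                          FutureWarning)
            self.n_estimators_ = 10        # keep the old default for now
        else:
            self.n_estimators_ = self.n_estimators
        return self
```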



def test_nestimators_future_warning():
# Test that n_estimators future warning is raised. Will be removed in 0.22
Member

You can use FIXME: to be removed in 0.22


@glemaitre glemaitre changed the title Fix to Issue #11128: Create deprecation warning for default n_estimators in RandomForest EHN: Change default n_estimators to 100 for random forest Jul 16, 2018
@glemaitre glemaitre changed the title EHN: Change default n_estimators to 100 for random forest [MRG] EHN: Change default n_estimators to 100 for random forest Jul 16, 2018
@glemaitre
Member

FYI: I updated the title of this PR.

@massich
Contributor

massich commented Jul 16, 2018

@annaayzenshtat this is a blocker for 0.20 (which we are actively working on right now). If you don't have time to address the comments at this moment that's completely fine. Ping me and I'll take over the PR.

@annaayzenshtat
Contributor Author

I'm still working on this issue

@annaayzenshtat
Contributor Author

I committed the requested changes. Please take a look at these code changes.

@glemaitre
Member

Actually you need to flag the tests with pytest.mark.filterwarnings to avoid raising the future warning in the tests (typically the ones that do not set n_estimators).
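For illustration, flagging a test might look like the sketch below (the test name is hypothetical; pytest matches the filter string against the start of the warning message):

```python
import pytest

# Hypothetical test that fits a forest without setting n_estimators and
# would therefore trigger the FutureWarning; the mark silences warnings
# whose message starts with the given prefix.
@pytest.mark.filterwarnings(
    'ignore:The default value of n_estimators:FutureWarning')
def test_forest_defaults():
    pass  # ... fit an estimator with the default n_estimators here ...
```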

@annaayzenshtat
Contributor Author

Ok, I'll change it.

Member

@glemaitre glemaitre left a comment


You can check this PR as an example how to use pytest

https://github.com/scikit-learn/scikit-learn/pull/11574/files

@annaayzenshtat
Contributor Author

I flagged the test with pytest.mark.filterwarnings.

@glemaitre
Member

@annaayzenshtat I am helping a bit with the failure that you got, and I am filtering the warnings because it seems they appear in a lot of places.

Member

@amueller amueller left a comment


lgtm if tests pass

@annaayzenshtat
Contributor Author

Ok, thank you!

@amueller
Member

python2.7 test failure :-/

@amueller
Member

In SAG?!

@annaayzenshtat
Contributor Author

Is there something I'm supposed to do to fix the Python 2.7 failure?

@glemaitre
Member

Nope, this is a side effect already shown and solved in #11574.

@annaayzenshtat
Contributor Author

Ok.

@amueller amueller merged commit 2242c59 into scikit-learn:master Jul 17, 2018
@glemaitre
Member

@annaayzenshtat Thanks a lot for the contribution.
Feel free to take any other issue ;)

@annaayzenshtat annaayzenshtat deleted the fix/n_estimators_100 branch July 17, 2018 19:49
@annaayzenshtat
Contributor Author

Thank you!

Successfully merging this pull request may close these issues.

Change default n_estimators in RandomForest (to 100?)
5 participants