
Commit 2242c59

annaayzenshtat authored and amueller committed
[MRG] EHN: Change default n_estimators to 100 for random forest (#11542)
Reference Issues/PRs: Fixes #11128.

What does this implement/fix? Issues a deprecation warning (a FutureWarning) when the default ``n_estimators`` parameter of the forest estimators is used. A test is added for the warning message when the default parameter is used.
1 parent a496491 commit 2242c59

File tree

13 files changed: +119 −19 lines changed

doc/whats_new/v0.20.rst

Lines changed: 9 additions & 1 deletion
@@ -22,7 +22,6 @@ Highlights
 We have tried to improve our support for common data-science use-cases
 including missing values, categorical variables, heterogeneous data, and
 features/targets with unusual distributions.
-
 Missing values in features, represented by NaNs, are now accepted in
 column-wise preprocessing such as scalers. Each feature is fitted disregarding
 NaNs, and data containing NaNs can be transformed. The new :mod:`impute`
@@ -734,6 +733,15 @@ Datasets
 API changes summary
 -------------------
 
+Classifiers and regressors
+
+- The default value of the ``n_estimators`` parameter of
+  :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
+  :class:`ensemble.ExtraTreesClassifier`, :class:`ensemble.ExtraTreesRegressor`,
+  and :class:`ensemble.RandomTreesEmbedding` will change from 10 in version 0.20
+  to 100 in 0.22. A FutureWarning is raised when the default value is used.
+  :issue:`11542` by :user:`Anna Ayzenshtat <annaayzenshtat>`.
+
 Linear, kernelized and related models
 
 - Deprecate ``random_state`` parameter in :class:`svm.OneClassSVM` as the
examples/applications/plot_prediction_latency.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def plot_benchmark_throughput(throughputs, configuration):
      'complexity_label': 'non-zero coefficients',
      'complexity_computer': lambda clf: np.count_nonzero(clf.coef_)},
     {'name': 'RandomForest',
-     'instance': RandomForestRegressor(),
+     'instance': RandomForestRegressor(n_estimators=100),
      'complexity_label': 'estimators',
      'complexity_computer': lambda clf: clf.n_estimators},
     {'name': 'SVR',

examples/ensemble/plot_ensemble_oob.py

Lines changed: 6 additions & 3 deletions
@@ -45,15 +45,18 @@
 # error trajectory during training.
 ensemble_clfs = [
     ("RandomForestClassifier, max_features='sqrt'",
-        RandomForestClassifier(warm_start=True, oob_score=True,
+        RandomForestClassifier(n_estimators=100,
+                               warm_start=True, oob_score=True,
                                max_features="sqrt",
                                random_state=RANDOM_STATE)),
     ("RandomForestClassifier, max_features='log2'",
-        RandomForestClassifier(warm_start=True, max_features='log2',
+        RandomForestClassifier(n_estimators=100,
+                               warm_start=True, max_features='log2',
                                oob_score=True,
                                random_state=RANDOM_STATE)),
     ("RandomForestClassifier, max_features=None",
-        RandomForestClassifier(warm_start=True, max_features=None,
+        RandomForestClassifier(n_estimators=100,
+                               warm_start=True, max_features=None,
                                oob_score=True,
                                random_state=RANDOM_STATE))
 ]

examples/ensemble/plot_random_forest_regression_multioutput.py

Lines changed: 4 additions & 2 deletions
@@ -43,11 +43,13 @@
     X, y, train_size=400, test_size=200, random_state=4)
 
 max_depth = 30
-regr_multirf = MultiOutputRegressor(RandomForestRegressor(max_depth=max_depth,
+regr_multirf = MultiOutputRegressor(RandomForestRegressor(n_estimators=100,
+                                                          max_depth=max_depth,
                                                           random_state=0))
 regr_multirf.fit(X_train, y_train)
 
-regr_rf = RandomForestRegressor(max_depth=max_depth, random_state=2)
+regr_rf = RandomForestRegressor(n_estimators=100, max_depth=max_depth,
+                                random_state=2)
 regr_rf.fit(X_train, y_train)
 
 # Predict on new data

examples/ensemble/plot_voting_probas.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 from sklearn.ensemble import VotingClassifier
 
 clf1 = LogisticRegression(random_state=123)
-clf2 = RandomForestClassifier(random_state=123)
+clf2 = RandomForestClassifier(n_estimators=100, random_state=123)
 clf3 = GaussianNB()
 X = np.array([[-1.0, -1.0], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
 y = np.array([1, 1, 2, 2])

sklearn/ensemble/forest.py

Lines changed: 34 additions & 9 deletions
@@ -135,7 +135,7 @@ class BaseForest(six.with_metaclass(ABCMeta, BaseEnsemble)):
     @abstractmethod
     def __init__(self,
                  base_estimator,
-                 n_estimators=10,
+                 n_estimators=100,
                  estimator_params=tuple(),
                  bootstrap=False,
                  oob_score=False,
@@ -242,6 +242,12 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
+
+        if self.n_estimators == 'warn':
+            warnings.warn("The default value of n_estimators will change from "
+                          "10 in version 0.20 to 100 in 0.22.", FutureWarning)
+            self.n_estimators = 10
+
         # Validate or convert input data
         X = check_array(X, accept_sparse="csc", dtype=DTYPE)
         y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
@@ -400,7 +406,7 @@ class ForestClassifier(six.with_metaclass(ABCMeta, BaseForest,
     @abstractmethod
     def __init__(self,
                  base_estimator,
-                 n_estimators=10,
+                 n_estimators=100,
                  estimator_params=tuple(),
                  bootstrap=False,
                  oob_score=False,
@@ -409,7 +415,6 @@ def __init__(self,
                  verbose=0,
                  warm_start=False,
                  class_weight=None):
-
         super(ForestClassifier, self).__init__(
             base_estimator,
             n_estimators=n_estimators,
@@ -640,7 +645,7 @@ class ForestRegressor(six.with_metaclass(ABCMeta, BaseForest, RegressorMixin)):
     @abstractmethod
     def __init__(self,
                  base_estimator,
-                 n_estimators=10,
+                 n_estimators=100,
                  estimator_params=tuple(),
                  bootstrap=False,
                  oob_score=False,
@@ -760,6 +765,10 @@ class RandomForestClassifier(ForestClassifier):
     n_estimators : integer, optional (default=10)
         The number of trees in the forest.
 
+        .. versionchanged:: 0.20
+           The default value of ``n_estimators`` will change from 10 in
+           version 0.20 to 100 in version 0.22.
+
     criterion : string, optional (default="gini")
         The function to measure the quality of a split. Supported criteria are
         "gini" for the Gini impurity and "entropy" for the information gain.
@@ -973,7 +982,7 @@ class labels (multi-output problem).
     DecisionTreeClassifier, ExtraTreesClassifier
     """
     def __init__(self,
-                 n_estimators=10,
+                 n_estimators='warn',
                  criterion="gini",
                  max_depth=None,
                  min_samples_split=2,
@@ -1034,6 +1043,10 @@ class RandomForestRegressor(ForestRegressor):
     n_estimators : integer, optional (default=10)
         The number of trees in the forest.
 
+        .. versionchanged:: 0.20
+           The default value of ``n_estimators`` will change from 10 in
+           version 0.20 to 100 in version 0.22.
+
     criterion : string, optional (default="mse")
         The function to measure the quality of a split. Supported criteria
         are "mse" for the mean squared error, which is equal to variance
@@ -1213,7 +1226,7 @@ class RandomForestRegressor(ForestRegressor):
     DecisionTreeRegressor, ExtraTreesRegressor
     """
     def __init__(self,
-                 n_estimators=10,
+                 n_estimators='warn',
                  criterion="mse",
                  max_depth=None,
                  min_samples_split=2,
@@ -1270,6 +1283,10 @@ class ExtraTreesClassifier(ForestClassifier):
     n_estimators : integer, optional (default=10)
         The number of trees in the forest.
 
+        .. versionchanged:: 0.20
+           The default value of ``n_estimators`` will change from 10 in
+           version 0.20 to 100 in version 0.22.
+
     criterion : string, optional (default="gini")
         The function to measure the quality of a split. Supported criteria are
         "gini" for the Gini impurity and "entropy" for the information gain.
@@ -1456,7 +1473,7 @@ class labels (multi-output problem).
     splits.
     """
     def __init__(self,
-                 n_estimators=10,
+                 n_estimators='warn',
                  criterion="gini",
                  max_depth=None,
                  min_samples_split=2,
@@ -1515,6 +1532,10 @@ class ExtraTreesRegressor(ForestRegressor):
     n_estimators : integer, optional (default=10)
         The number of trees in the forest.
 
+        .. versionchanged:: 0.20
+           The default value of ``n_estimators`` will change from 10 in
+           version 0.20 to 100 in version 0.22.
+
     criterion : string, optional (default="mse")
         The function to measure the quality of a split. Supported criteria
         are "mse" for the mean squared error, which is equal to variance
@@ -1668,7 +1689,7 @@ class ExtraTreesRegressor(ForestRegressor):
     RandomForestRegressor: Ensemble regressor using trees with optimal splits.
     """
     def __init__(self,
-                 n_estimators=10,
+                 n_estimators='warn',
                  criterion="mse",
                  max_depth=None,
                  min_samples_split=2,
@@ -1730,6 +1751,10 @@ class RandomTreesEmbedding(BaseForest):
     n_estimators : integer, optional (default=10)
         Number of trees in the forest.
 
+        .. versionchanged:: 0.20
+           The default value of ``n_estimators`` will change from 10 in
+           version 0.20 to 100 in version 0.22.
+
     max_depth : integer, optional (default=5)
         The maximum depth of each tree. If None, then nodes are expanded until
         all leaves are pure or until all leaves contain less than
@@ -1832,7 +1857,7 @@ class RandomTreesEmbedding(BaseForest):
     """
 
     def __init__(self,
-                 n_estimators=10,
+                 n_estimators='warn',
                  max_depth=5,
                  min_samples_split=2,
                  min_samples_leaf=1,
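The fit-time check added to BaseForest.fit above is an instance of the sentinel-default deprecation pattern: the signature default becomes a placeholder string ('warn'), and fit resolves it to the old value while emitting a FutureWarning, so any explicitly passed integer stays silent. A minimal standalone sketch, where ToyForest is a hypothetical stand-in and not scikit-learn code:

```python
import warnings


class ToyForest:
    """Hypothetical stand-in for the forest estimators."""

    def __init__(self, n_estimators='warn'):
        # scikit-learn estimators must not validate or warn in __init__
        # (clone/set_params depend on it), so the sentinel is stored
        # untouched and only resolved in fit.
        self.n_estimators = n_estimators

    def fit(self):
        if self.n_estimators == 'warn':
            warnings.warn("The default value of n_estimators will change "
                          "from 10 in version 0.20 to 100 in 0.22.",
                          FutureWarning)
            self.n_estimators = 10  # keep the old behaviour for now
        return self


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    default_est = ToyForest().fit()      # warns, falls back to 10
    explicit_est = ToyForest(100).fit()  # silent

print(default_est.n_estimators, explicit_est.n_estimators, len(caught))
# -> 10 100 1
```

Deferring the warning to fit is what lets pipelines and grid searches construct estimators without noise until they are actually trained.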

sklearn/ensemble/tests/test_forest.py

Lines changed: 34 additions & 0 deletions
@@ -31,6 +31,7 @@
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import assert_warns_message
+from sklearn.utils.testing import assert_no_warnings
 from sklearn.utils.testing import ignore_warnings
 
 from sklearn import datasets
@@ -186,6 +187,7 @@ def check_regressor_attributes(name):
     assert_false(hasattr(r, "n_classes_"))
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_REGRESSORS)
 def test_regressor_attributes(name):
     check_regressor_attributes(name)
@@ -432,6 +434,7 @@ def check_oob_score_raise_error(name):
                   bootstrap=False).fit, X, y)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_ESTIMATORS)
 def test_oob_score_raise_error(name):
     check_oob_score_raise_error(name)
@@ -489,6 +492,7 @@ def check_pickle(name, X, y):
     assert_equal(score, score2)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS)
 def test_pickle(name):
     if name in FOREST_CLASSIFIERS:
@@ -526,6 +530,7 @@ def check_multioutput(name):
         assert_equal(log_proba[1].shape, (4, 4))
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS)
 def test_multioutput(name):
     check_multioutput(name)
@@ -549,6 +554,7 @@ def check_classes_shape(name):
     assert_array_equal(clf.classes_, [[-1, 1], [-2, 2]])
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
 def test_classes_shape(name):
     check_classes_shape(name)
@@ -738,6 +744,7 @@ def check_min_samples_split(name):
                "Failed with {0}".format(name))
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_ESTIMATORS)
 def test_min_samples_split(name):
     check_min_samples_split(name)
@@ -775,6 +782,7 @@ def check_min_samples_leaf(name):
                "Failed with {0}".format(name))
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_ESTIMATORS)
 def test_min_samples_leaf(name):
     check_min_samples_leaf(name)
@@ -842,6 +850,7 @@ def check_sparse_input(name, X, X_sparse, y):
                               dense.fit_transform(X).toarray())
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_ESTIMATORS)
 @pytest.mark.parametrize('sparse_matrix',
                          (csr_matrix, csc_matrix, coo_matrix))
@@ -899,6 +908,7 @@ def check_memory_layout(name, dtype):
     assert_array_almost_equal(est.fit(X, y).predict(X), y)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS)
 @pytest.mark.parametrize('dtype', (np.float64, np.float32))
 def test_memory_layout(name, dtype):
@@ -977,6 +987,7 @@ def check_class_weights(name):
     clf.fit(iris.data, iris.target, sample_weight=sample_weight)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
 def test_class_weights(name):
     check_class_weights(name)
@@ -996,6 +1007,7 @@ def check_class_weight_balanced_and_bootstrap_multi_output(name):
     clf.fit(X, _y)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
 def test_class_weight_balanced_and_bootstrap_multi_output(name):
     check_class_weight_balanced_and_bootstrap_multi_output(name)
@@ -1026,6 +1038,7 @@ def check_class_weight_errors(name):
     assert_raises(ValueError, clf.fit, X, _y)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 @pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
 def test_class_weight_errors(name):
     check_class_weight_errors(name)
@@ -1163,6 +1176,7 @@ def test_warm_start_oob(name):
     check_warm_start_oob(name)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 def test_dtype_convert(n_classes=15):
     classifier = RandomForestClassifier(random_state=0, bootstrap=False)
 
@@ -1201,6 +1215,7 @@ def test_decision_path(name):
     check_decision_path(name)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 def test_min_impurity_split():
     # Test if min_impurity_split of base estimators is set
     # Regression test for #8006
@@ -1216,6 +1231,7 @@ def test_min_impurity_split():
     assert_equal(tree.min_impurity_split, 0.1)
 
 
+@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
 def test_min_impurity_decrease():
     X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
     all_estimators = [RandomForestClassifier, RandomForestRegressor,
@@ -1228,3 +1244,21 @@ def test_min_impurity_decrease():
     # Simply check if the parameter is passed on correctly. Tree tests
     # will suffice for the actual working of this param
     assert_equal(tree.min_impurity_decrease, 0.1)
+
+
+@pytest.mark.parametrize('forest',
+                         [RandomForestClassifier, RandomForestRegressor,
+                          ExtraTreesClassifier, ExtraTreesRegressor,
+                          RandomTreesEmbedding])
+def test_nestimators_future_warning(forest):
+    # FIXME: to be removed 0.22
+
+    # When n_estimators default value is used
+    msg_future = ("The default value of n_estimators will change from "
+                  "10 in version 0.20 to 100 in 0.22.")
+    est = forest()
+    est = assert_warns_message(FutureWarning, msg_future, est.fit, X, y)
+
+    # When n_estimators is a valid value not equal to the default
+    est = forest(n_estimators=100)
+    est = assert_no_warnings(est.fit, X, y)
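The new test leans on assert_warns_message and assert_no_warnings from sklearn.utils.testing. Their core behaviour can be sketched with the stdlib alone; warns_with_message below is a hypothetical helper, not the scikit-learn implementation:

```python
import warnings


def warns_with_message(category, expected_msg, func, *args, **kwargs):
    """Call func(*args, **kwargs) and assert that it emits a warning of
    `category` whose text contains `expected_msg`; return func's result."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = func(*args, **kwargs)
    assert any(issubclass(w.category, category)
               and expected_msg in str(w.message) for w in caught), \
        "expected %s matching %r" % (category.__name__, expected_msg)
    return result


def legacy_default():
    # Stand-in for fitting an estimator that still uses the old default.
    warnings.warn("The default value of n_estimators will change from "
                  "10 in version 0.20 to 100 in 0.22.", FutureWarning)
    return 10


n = warns_with_message(FutureWarning, "will change from", legacy_default)
print(n)  # 10
```

Recording with simplefilter("always") inside catch_warnings is what makes the check deterministic: it bypasses the once-per-location warning registry, so the test passes even if the same warning was already triggered earlier in the session.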

0 commit comments