DOC Add warm start section for tree ensembles #29001


Merged: 6 commits, May 14, 2024
37 changes: 37 additions & 0 deletions doc/modules/ensemble.rst
@@ -1232,6 +1232,43 @@ estimation.
representations of feature space; these approaches also focus on
dimensionality reduction.

.. _tree_ensemble_warm_start:

Fitting additional trees
------------------------

RandomForest, Extra-Trees and :class:`RandomTreesEmbedding` estimators all support
``warm_start=True``, which allows you to add more trees to an already fitted model.

::

>>> from sklearn.datasets import make_classification
>>> from sklearn.ensemble import RandomForestClassifier

>>> X, y = make_classification(n_samples=100, random_state=1)
>>> clf = RandomForestClassifier(n_estimators=10, random_state=0)
>>> clf = clf.fit(X, y)  # fit with 10 trees
>>> len(clf.estimators_)
10
>>> # set warm_start and increase the number of estimators
>>> _ = clf.set_params(n_estimators=20, warm_start=True)
>>> _ = clf.fit(X, y)  # fit 10 additional trees
>>> len(clf.estimators_)
20

When ``random_state`` is also set, the internal random state is preserved
between ``fit`` calls. This means that training a model once with ``n`` estimators is
the same as building the model iteratively via multiple ``fit`` calls, where the
final number of estimators is equal to ``n``.

::

>>> clf = RandomForestClassifier(n_estimators=20, random_state=0)  # set `n_estimators` to 10 + 10
>>> _ = clf.fit(X, y)  # with a fixed `random_state`, `estimators_` will be the same as above

Note that this differs from the usual behavior of :term:`random_state` in that it does
*not* produce the same model across different calls.
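As a sketch (not part of the PR diff), the equivalence documented above can be checked
directly by comparing the learned tree structures of an incrementally grown forest
against a single-shot fit; the parameter values here are illustrative:

```python
# Verify: growing a forest incrementally with warm_start and a fixed
# random_state yields the same trees as one fit with the final n_estimators.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=100, random_state=1)

# Incremental: fit 10 trees, then warm-start up to 20.
inc = RandomForestClassifier(n_estimators=10, random_state=0)
inc.fit(X, y)
inc.set_params(n_estimators=20, warm_start=True)
inc.fit(X, y)

# Single shot: fit all 20 trees at once with the same random_state.
one = RandomForestClassifier(n_estimators=20, random_state=0)
one.fit(X, y)

# Compare split features and thresholds node by node across all trees.
same = all(
    (a.tree_.feature == b.tree_.feature).all()
    and (a.tree_.threshold == b.tree_.threshold).all()
    for a, b in zip(inc.estimators_, one.estimators_)
)
print(same)
```

If ``random_state`` were left unset, the two forests would generally differ, which is
why the doctest above fixes it.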

.. _bagging:

Bagging meta-estimator
10 changes: 5 additions & 5 deletions sklearn/ensemble/_forest.py
@@ -1308,7 +1308,7 @@ class RandomForestClassifier(ForestClassifier):
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble; otherwise, just fit a whole
new forest. See :term:`Glossary <warm_start>` and
:ref:`gradient_boosting_warm_start` for details.
:ref:`tree_ensemble_warm_start` for details.

class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
default=None
@@ -1710,7 +1710,7 @@ class RandomForestRegressor(ForestRegressor):
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble; otherwise, just fit a whole
new forest. See :term:`Glossary <warm_start>` and
:ref:`gradient_boosting_warm_start` for details.
:ref:`tree_ensemble_warm_start` for details.

ccp_alpha : non-negative float, default=0.0
Complexity parameter used for Minimal Cost-Complexity Pruning. The
@@ -2049,7 +2049,7 @@ class ExtraTreesClassifier(ForestClassifier):
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble; otherwise, just fit a whole
new forest. See :term:`Glossary <warm_start>` and
:ref:`gradient_boosting_warm_start` for details.
:ref:`tree_ensemble_warm_start` for details.

class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
default=None
@@ -2434,7 +2434,7 @@ class ExtraTreesRegressor(ForestRegressor):
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble; otherwise, just fit a whole
new forest. See :term:`Glossary <warm_start>` and
:ref:`gradient_boosting_warm_start` for details.
:ref:`tree_ensemble_warm_start` for details.

ccp_alpha : non-negative float, default=0.0
Complexity parameter used for Minimal Cost-Complexity Pruning. The
@@ -2727,7 +2727,7 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest):
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble; otherwise, just fit a whole
new forest. See :term:`Glossary <warm_start>` and
:ref:`gradient_boosting_warm_start` for details.
:ref:`tree_ensemble_warm_start` for details.

Attributes
----------
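The docstring updates above also cover :class:`RandomTreesEmbedding`, which is fitted
without a target. A brief sketch (parameter values are illustrative, not taken from the
PR) of warm-starting it:

```python
# Warm-start a RandomTreesEmbedding: add 10 trees to an already fitted
# transformer without refitting the first 10.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomTreesEmbedding

X, _ = make_classification(n_samples=100, random_state=1)

emb = RandomTreesEmbedding(n_estimators=10, random_state=0)
emb.fit(X)  # fit 10 totally random trees; no target is needed

emb.set_params(n_estimators=20, warm_start=True)
emb.fit(X)  # fit 10 additional trees
print(len(emb.estimators_))
```

The transformer's one-hot leaf encoding then reflects all 20 trees.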