Changes from all commits
60 commits
09f02c0
implemented perm imp in oob_score for class and reg
robert-robison Oct 7, 2020
a4984a9
Fixed RandomTreesEmbedding compatability error
robert-robison Oct 7, 2020
e46d3e2
Put permutation importance in its own method
robert-robison Oct 8, 2020
965e02a
added tests and fixed random_state and formatting
robert-robison Oct 9, 2020
87fbeae
Add scorer, update example, changed param name
robert-robison Jan 10, 2021
05dab08
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
robert-robison Jan 11, 2021
98a94a6
formatting
robert-robison Jan 11, 2021
93470c9
fixed bug in example
robert-robison Jan 11, 2021
4b1f394
Apply suggestions from code review
robert-robison Jan 12, 2021
dfeaf52
parallelized, removed scoring, fixed tests
robert-robison Jan 12, 2021
d842f14
formatting
robert-robison Jan 12, 2021
99414a5
Add random feature test
robert-robison Jan 12, 2021
2d33c43
formatting
robert-robison Jan 12, 2021
7c7a1cf
Merge remote-tracking branch 'upstream/main' into rf_permutation_impo…
robert-robison Jan 26, 2021
fa4cb3b
remove inspection dependency
robert-robison Jan 26, 2021
f379d00
integrate permutation imp with oob score
robert-robison Jan 28, 2021
a29fa77
MNT refactoring based on further multiprocessing
glemaitre Jan 28, 2021
3143f73
doc
glemaitre Jan 28, 2021
e274b39
ENH parallelize
glemaitre Jan 28, 2021
45ac1b5
doc
glemaitre Jan 28, 2021
a310862
less diff
glemaitre Jan 28, 2021
dec9456
improve doc
glemaitre Jan 28, 2021
9206908
TST check for features importances raised error
glemaitre Jan 28, 2021
7a44e25
reformat tests, update example
robert-robison Jan 29, 2021
d5849d4
DOC rework the example
glemaitre Jan 29, 2021
5252900
glitch
glemaitre Jan 29, 2021
3395309
DOC update user guide
glemaitre Jan 29, 2021
81beb84
DOC solve title marker
glemaitre Jan 29, 2021
8129850
DOC improve example regarding feature importance
glemaitre Jan 29, 2021
5b41963
DOC add new attributes importances_
glemaitre Jan 29, 2021
c885a40
DOC update whats new
glemaitre Jan 29, 2021
4056f8d
update docstring feature_importances_
glemaitre Jan 29, 2021
dbba6cc
TST add test for importances_ attribute
glemaitre Jan 29, 2021
a53ba26
PEP8
glemaitre Jan 29, 2021
005f788
TST improve couple of assert
glemaitre Jan 29, 2021
cd42500
DOC add missing documentation
glemaitre Jan 29, 2021
5f78f3a
clean-up
glemaitre Jan 29, 2021
d741ea8
small fix
glemaitre Jan 29, 2021
1b4a779
DOC use boxplot for all plot in example
glemaitre Jan 29, 2021
d1b0208
EXA solve issue cutted ylabel
glemaitre Jan 29, 2021
20ecb39
DOC add support for sample_weight in OOB score
glemaitre Jan 30, 2021
9102d5f
style code
glemaitre Jan 30, 2021
620e643
DOC add a note regarding correlated feature
glemaitre Jan 30, 2021
3f8216a
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
robert-robison Mar 2, 2021
39b40bb
Attempting to resolve merge in examples
robert-robison Mar 3, 2021
83e61dc
formatting
robert-robison Mar 3, 2021
633b7fe
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
robert-robison Mar 3, 2021
0271b51
formatting
robert-robison Mar 3, 2021
9807344
Merge branch 'main' into rf_permutation_importance
ogrisel Mar 23, 2021
3289a24
Apply suggestions from code review
robert-robison Jun 19, 2021
48c7c7c
Apply additional suggestions from code review
robert-robison Jun 20, 2021
0225547
Apply suggestions from code review
robert-robison Jun 22, 2021
3c404cb
Merge commit '0e7761cdc4f244adb4803f1a97f0a9fe4b365a99' into rf_permu…
robert-robison Jun 22, 2021
a47784f
MAINT Adds target_version to black config (#20293)
thomasjpfan Jun 17, 2021
3695ef2
black formatted changes
robert-robison Jun 22, 2021
86bb5f0
Merge remote-tracking branch 'upstream/main' into rf_permutation_impo…
robert-robison Jun 22, 2021
096abe5
remove old assert_allclose import
robert-robison Jun 22, 2021
be2392e
Apply suggested format changes
robert-robison Jun 22, 2021
d3ddaf8
Reference section edits
robert-robison Jun 22, 2021
280f9d9
Update examples/ensemble/plot_forest_importances.py
robert-robison Jul 31, 2021
87 changes: 69 additions & 18 deletions doc/modules/ensemble.rst
@@ -262,6 +262,15 @@ amount of time (e.g., on large datasets).
Feature importance evaluation
-----------------------------

Both random forest and extremely randomized tree estimators provide a fitted
attribute `feature_importances_` giving an estimate of the relative feature
importance. Two strategies are available to estimate the feature importances,
and the strategy can be selected with the parameter `feature_importances`. The
following sections describe each strategy.
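
For instance, the strategy can be selected at construction time through the
`feature_importances` parameter proposed in this pull request (a sketch of the
API added by this diff, not of a released scikit-learn version)::

    from sklearn.ensemble import RandomForestClassifier

    forest = RandomForestClassifier(feature_importances="permutation_oob")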

Mean decrease in impurity (MDI)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The relative rank (i.e. depth) of a feature used as a decision node in a
tree can be used to assess the relative importance of that feature with
respect to the predictability of the target variable. Features used at
@@ -279,18 +288,53 @@ for feature selection. This is known as the mean decrease in impurity, or MDI.
Refer to [L2014]_ for more information on MDI and feature importance
evaluation with Random Forests.

This strategy corresponds to setting `feature_importances="impurity"`, which is
the default value.

.. warning::

The impurity-based feature importances computed on tree-based models suffer
from two flaws that can lead to misleading conclusions:

- Firstly, they are computed on statistics derived from the training
  dataset and therefore **do not necessarily inform us on which features are
  most important to make good predictions on a held-out dataset**. [Strobl07]_
- Secondly, they favor **high cardinality features**, that is, features with
  many unique values. [White94]_

Feature importances estimated through feature permutation are an alternative
that does not suffer from these flaws. We give more details regarding this
alternative in the next section.

Permutation feature importances on out-of-bag (OOB) samples
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

An alternative to MDI is to estimate feature importances via feature
permutation, referred to as **permutation feature importances**.

Each tree in the ensemble can be evaluated using its out-of-bag samples
[B2001]_. To assess the importance of a feature, one computes the difference
between the tree score on the original OOB sample and the score on an OOB
sample in which the feature of interest has been permuted. The permutation
feature importance then corresponds to the average decrease of the tree score.
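
Denoting by :math:`T` the number of trees in the forest, by :math:`s_t` the
score of tree :math:`t` on its original OOB sample, and by :math:`s_t^{(j)}`
its score after permuting feature :math:`j` (notation introduced here for
illustration only), the description above amounts to:

.. math::

    \text{Importance}(j) = \frac{1}{T} \sum_{t=1}^{T} \left( s_t - s_t^{(j)} \right)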

When a feature has significant predictive power, one expects the score to
decrease. If instead the score remains unchanged, the feature is not important
for predicting the target.
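
For illustration only, the per-tree computation described above could be
sketched as follows; the helper and its arguments are hypothetical and do not
reflect the actual implementation in this pull request, which is parallelized
and integrated with the OOB scoring machinery::

    import numpy as np

    def tree_permutation_importance_oob(tree, X_oob, y_oob, rng):
        # Score of the tree on the untouched OOB sample.
        baseline = tree.score(X_oob, y_oob)
        importances = np.empty(X_oob.shape[1])
        for j in range(X_oob.shape[1]):
            X_perm = X_oob.copy()
            # Permuting a column breaks the link between feature j and y.
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            importances[j] = baseline - tree.score(X_perm, y_oob)
        return importances

Averaging this quantity over all trees (each with its own OOB sample) yields
the permutation feature importances.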

This strategy can be selected by setting
`feature_importances="permutation_oob"`.

.. note::

:ref:`permutation_importance` can also be evaluated on a held-out set by
manually splitting the dataset into train and test sets. In this case, the
permutation procedure is applied on the test set rather than on the OOB
samples. The function :func:`~sklearn.inspection.permutation_importance`
should be used for this purpose.
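
For instance, the held-out variant looks as follows (assuming a fitted
`forest` and an existing `X_test`/`y_test` split, as in the example below)::

    from sklearn.inspection import permutation_importance

    result = permutation_importance(
        forest, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
    )
    print(result.importances_mean)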

Illustration of using feature importances
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following example shows a color-coded representation of the relative
importances of each individual pixel for a face recognition task using
@@ -301,22 +345,29 @@ an :class:`ExtraTreesClassifier` model.
:align: center
:scale: 75

In practice those estimates are stored as an attribute named
``feature_importances_`` on the fitted model. This is an array with shape
``(n_features,)``. The higher the value, the more important is the
contribution of the matching feature to the prediction function. With the
default impurity-based strategy, the values are positive and sum to 1.0;
the permutation feature importances are not normalized.

MDI and the permutation feature importances are explored in:
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`.

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_ensemble_plot_forest_importances_faces.py`
* :ref:`sphx_glr_auto_examples_ensemble_plot_forest_importances.py`
* :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`

.. topic:: References

.. [Strobl07] `Strobl, C., Boulesteix, AL., Zeileis, A. et al.
Bias in random forest variable importance measures: Illustrations,
sources and a solution.
BMC Bioinformatics 8, 25 (2007).
<https://doi.org/10.1186/1471-2105-8-25>`_
.. [White94] `White, A.P., Liu, W.Z. Technical Note:
Bias in Information-Based Measures in Decision Tree Induction.
Machine Learning 15, 321–329 (1994).
<https://doi.org/10.1023/A:1022694010754>`_
.. [L2014] G. Louppe,
"Understanding Random Forests: From Theory to Practice",
PhD Thesis, U. of Liege, 2014.

.. _random_trees_embedding:

@@ -624,7 +675,7 @@ chapter on gradient boosting in [F2001]_ and is related to the parameter
``interaction.depth`` in R's gbm package where ``max_leaf_nodes == interaction.depth + 1`` .

Mathematical formulation
------------------------

We first present GBRT for regression, and then detail the classification
case.
21 changes: 21 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -269,6 +269,27 @@ Changelog
target. Additional private refactoring was performed.
:pr:`19162` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Feature| Implement out-of-bag permutation feature importances by setting
the parameter `feature_importances="permutation_oob"` in
:class:`ensemble.RandomForestClassifier`,
:class:`ensemble.RandomForestRegressor`,
:class:`ensemble.ExtraTreesClassifier`, and
:class:`ensemble.ExtraTreesRegressor`.
:pr:`18603` by :user:`Robert Robison <robert-robison>`.

- |Feature| A new fitted attribute `importances_` has been introduced reporting
the impurity-based or permutation feature importances. This attribute is a
:class:`~sklearn.utils.Bunch` storing the raw per-tree importances as well as
their mean and variation across all trees of the forest.
:pr:`18603` by :user:`Robert Robison <robert-robison>`.

- |Enhancement| The OOB score reported in
:class:`ensemble.RandomForestClassifier`,
:class:`ensemble.RandomForestRegressor`,
:class:`ensemble.ExtraTreesClassifier`, and
:class:`ensemble.ExtraTreesRegressor` now takes `sample_weight` into
account, whereas it was previously ignored.
:pr:`18603` by :user:`Robert Robison <robert-robison>`.

- |Enhancement| :class:`~sklearn.ensemble.HistGradientBoostingClassifier` and
:class:`~sklearn.ensemble.HistGradientBoostingRegressor` are no longer
experimental. They are now considered stable and are subject to the same
109 changes: 65 additions & 44 deletions examples/ensemble/plot_forest_importances.py
@@ -4,43 +4,52 @@
==========================================

This example shows the use of a forest of trees to evaluate the importance of
features on an artificial classification task.

We show two strategies to estimate the feature importances: (i) the
impurity-based feature importances and (ii) the permutation feature
importances on out-of-bag (OOB) samples.

.. warning::
Impurity-based feature importances can be misleading for high cardinality
features (many unique values). See the documentation of the
`feature_importances` parameter for more details regarding the alternative,
the permutation feature importances.
"""
print(__doc__)
import matplotlib.pyplot as plt

# %%
# Data generation and model fitting
# ---------------------------------
# We generate a synthetic dataset with only 3 informative features. We will
# explicitly not shuffle the dataset to ensure that the informative features
# correspond to the three first columns of `X`. In addition, we will split our
# dataset into training and testing subsets.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(
    n_samples=1000,
    n_features=10,
    n_informative=3,
    n_redundant=0,
    n_repeated=0,
    n_classes=2,
    random_state=0,
    shuffle=False,
)
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

# %%
# A random forest classifier will be fitted to compute the feature importances.
from sklearn.ensemble import RandomForestClassifier

feature_names = [f"feature {i}" for i in range(X.shape[1])]

# MDI-based feature importance ("impurity") is the default
forest = RandomForestClassifier(feature_importances="impurity", random_state=0)

# %%
# Feature importance based on Mean Decrease in Impurity (MDI)
# -----------------------------------------------------------
# Feature importances are provided by the fitted attribute
# `feature_importances_` and they are computed as the mean and standard
# deviation of the accumulated impurity decrease within each tree.
@@ -53,17 +62,19 @@
import numpy as np

start_time = time.time()
forest.fit(X_train, y_train)
importances = forest.feature_importances_
std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
elapsed_time = time.time() - start_time

print(f"Elapsed time to compute the importances: "
f"{elapsed_time:.3f} seconds")
print(
f"Elapsed time to train and compute the importances: " f"{elapsed_time:.3f} seconds"
)

# %%
# Let's plot the impurity-based importance.
import pandas as pd

forest_importances = pd.Series(importances, index=feature_names)

fig, ax = plt.subplots()
@@ -75,36 +86,46 @@
# %%
# We observe that, as expected, the three first features are found important.
#
# Permutation Feature Importances on OOB samples
# ----------------------------------------------
# We will use an alternative to the impurity-based feature importances, based
# on feature permutation using the OOB samples. We fit a new random forest
# where we explicitly specify to compute the permutation feature importances
# on the OOB samples.
feature_names = [f"feature {i}" for i in range(X.shape[1])]
forest = RandomForestClassifier(feature_importances="permutation_oob", random_state=0)
start_time = time.time()
forest.fit(X_train, y_train)
elapsed_time = time.time() - start_time
print(f"Elapsed time to compute the importances: "
f"{elapsed_time:.3f} seconds")

print(
    f"Elapsed time to train and compute the importances: {elapsed_time:.3f} seconds"
)

# %%
forest_importances = pd.Series(forest.feature_importances_, index=feature_names)

# %%
# The permutation importance is more computationally costly. Indeed, it
# requires fitting the trees and performing additional processing: each tree
# is evaluated on its OOB sample as well as on an OOB sample where features
# are permuted. This extra step is costly and explains the difference in fit
# time between the two forests.
#
# We now plot the feature importance ranking.
fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=forest.importances_.importances_std, ax=ax)
ax.set_title("Feature importances using permutation on full model")
ax.set_ylabel("Mean accuracy decrease")
fig.tight_layout()
plt.show()

# %%
# As in the impurity-based case, the three most important features are detected.
# We see that non-important features have a mean accuracy decrease of zero.
# Hence, permuting these features did not have an impact on the score.
#
# Another difference between the two feature importances is the scale of the
# reported values:
#
# - the permutation feature importances are not normalized and simply
#   correspond to a difference of scores;
# - the impurity-based feature importances are normalized so that they
#   sum to 1.