82 commits
c0e22ea
First working implementation of UFI, does not support multi output, h…
GaetandeCast Apr 14, 2025
b1e9df8
Removed the normalization inherited from the old MDI to avoid instabi…
GaetandeCast Apr 15, 2025
2a694b6
added multi output support
GaetandeCast Apr 15, 2025
fd0abfb
removed redundant cross_impurity computations
GaetandeCast Apr 15, 2025
ef9f48d
added mdi_oob
GaetandeCast Apr 16, 2025
a225a42
redesigned ufi for better memory management
GaetandeCast Apr 17, 2025
83f3880
removed a debug import
GaetandeCast Apr 17, 2025
27618db
added mdi_oob, cleaned the code
GaetandeCast Apr 18, 2025
5ad9636
better unified the code between ufi and mdi_oob
GaetandeCast Apr 18, 2025
21d2e04
fixed a call oversight
GaetandeCast Apr 18, 2025
8194d6e
fixed an error in mdi_oob computations
GaetandeCast Apr 18, 2025
9e16a09
changed tests on feature_importances_ to use unbiased FI too
GaetandeCast Apr 22, 2025
8991d79
add tests to check that the added methods coincide with the papers an…
GaetandeCast Apr 23, 2025
a9d2983
added support for regression (only MSE split)
GaetandeCast Apr 24, 2025
710d42c
added warning for unbiased feature importance in classification witho…
GaetandeCast Apr 24, 2025
ddedf27
merged test_non_OOB_unbiased_feature_importances_class & _reg
GaetandeCast Apr 24, 2025
1de98fc
Fixed a few mistake so that ufi-regression matches feature_importance…
GaetandeCast Apr 25, 2025
c7c5d76
Extended the tests on matching the paper values to regression
GaetandeCast Apr 25, 2025
a44084d
re added tests on oob_score for dense X. They fail
GaetandeCast Apr 25, 2025
082206c
revert a small change to a test
GaetandeCast Apr 28, 2025
b028cb9
raise an error when calling unbiased feature importance with criterio…
GaetandeCast Apr 28, 2025
dcb3106
adapted the tests to the previous commit
GaetandeCast Apr 29, 2025
c61c8dc
Added log_loss ufi
GaetandeCast Apr 29, 2025
d198f20
fixed the oob_score_ issue, simplified the self.value accesses
GaetandeCast Apr 29, 2025
f2acf5f
updated api and tests for ufi with 'log_loss'
GaetandeCast Apr 30, 2025
f41cf3f
divide by 2 ufi 'log_loss' and improve tests
GaetandeCast Apr 30, 2025
af785d6
fix some linting
GaetandeCast Apr 30, 2025
ccd4f18
fixed Cython linting
GaetandeCast Apr 30, 2025
ac36aaa
added inline function for clarity and comments on available criteria
GaetandeCast Apr 30, 2025
fda8349
Merge branch 'main' into main
ogrisel Apr 30, 2025
5f1beed
Merge branch 'main' into unbiased-feature-importance
GaetandeCast May 7, 2025
f10721e
add sample weight support
GaetandeCast May 7, 2025
6966147
add test reg mse 1hot == classi gini
GaetandeCast May 9, 2025
50d47a8
fix bug in previous commit and simplify test
GaetandeCast May 9, 2025
ce52159
add support for methods in gradient boosting, when
GaetandeCast May 12, 2025
2b9099b
support and test degenerate case (ensemble of single node trees)
GaetandeCast May 12, 2025
d54cf0f
improve sample weight support and test
GaetandeCast May 12, 2025
8a09b39
move gradient boosting changes to gradient-boosting branch
GaetandeCast May 14, 2025
9fbebb5
finish removing gradient-boosting changes
GaetandeCast May 14, 2025
6fcf61c
move previous sample weight test to tree level
GaetandeCast May 14, 2025
241de66
add sample weight tests
GaetandeCast May 14, 2025
4c40c43
add convergence test between biased and unbiased fi
GaetandeCast May 14, 2025
8329b3b
add support for scipy sparse matrices
GaetandeCast May 15, 2025
0b48af4
update importances test to test with sparse data
GaetandeCast May 15, 2025
229cc4d
Update doc example on permutation and MDI importance
GaetandeCast May 16, 2025
6e4703b
remove unused code
GaetandeCast May 21, 2025
e8b06dc
improve test_importances
GaetandeCast May 21, 2025
452019d
remove unnecessary validation on X
GaetandeCast May 22, 2025
6fd9304
make functions private to avoid docstring test fail
GaetandeCast May 22, 2025
bbbfb38
add TODO on the public aspect of the method
GaetandeCast May 22, 2025
77bc017
add sparse conversion to csr that was done by the removed validate_X_…
GaetandeCast May 22, 2025
a77fd2e
changed tree oob sample weight test to be more simple and understandable
GaetandeCast May 22, 2025
dd59f9c
Update sklearn/tree/tests/test_tree.py
GaetandeCast May 22, 2025
8ba000e
add global_random_seed to match_paper tests
GaetandeCast May 23, 2025
9a01dbe
add skip if joblib version <1.2, remove sample weight test that was m…
GaetandeCast May 23, 2025
0c5cab5
fix regex match
GaetandeCast May 23, 2025
9e78f6d
drop the return_as parameter, remove joblib version skip in tests
GaetandeCast May 23, 2025
77676e7
add non normalized feature importance in private attribute
GaetandeCast May 27, 2025
474839f
add changelog entries for tree and ensemble
GaetandeCast May 30, 2025
ce9ebe0
fix coverage warnings
GaetandeCast Jun 2, 2025
be7c99d
divide importances by weighted_n_sample to avoid large unnormalized v…
GaetandeCast Jun 4, 2025
41729c6
Merge branch 'main' into unbiased-feature-importance
GaetandeCast Jun 6, 2025
88c30eb
remove mdi_oob, remove normalization for ufi
GaetandeCast Jun 6, 2025
eb3316b
made ufi only available with gini, mse and friedman_mse
GaetandeCast Jun 10, 2025
4d39721
Apply suggestions from code review
GaetandeCast Jun 12, 2025
a66d8e1
Apply suggestions from code review
GaetandeCast Jun 12, 2025
137645b
fix linting and a typo
GaetandeCast Jun 12, 2025
f99b8b7
divide ufi regression by 2, add private unnormalised mdi, test unnorm…
GaetandeCast Jun 12, 2025
2322b26
make the method public, reorder cython code
GaetandeCast Jun 17, 2025
1759c87
simplify docstring and fix test in tree
GaetandeCast Jun 17, 2025
a7821f6
update changelog
GaetandeCast Jun 17, 2025
2f9afc5
Add paragraph in tree user guide
GaetandeCast Jun 18, 2025
abe0493
refactor doc example
GaetandeCast Jun 18, 2025
6157f08
fix docstring on scoring function
GaetandeCast Jun 19, 2025
7ef1da6
call unnormalized fi only once
GaetandeCast Jun 19, 2025
227ec3d
Apply suggestions from code review
GaetandeCast Jun 19, 2025
83c0a1a
remove unused code for oob_pred, clean aggregation of oob_pred
GaetandeCast Jun 19, 2025
e0313d3
improve references in docstrings and add versionadded
GaetandeCast Jun 19, 2025
cbfc53a
mention the new method in the class attribute docstring
GaetandeCast Jun 19, 2025
f4ac8e2
return_as generator and shorten method names
GaetandeCast Jun 25, 2025
dfbb104
raise error when calling ufi with oob_score=False and corresponding test
GaetandeCast Jun 25, 2025
7ea88ec
Merge branch 'main' into unbiased-feature-importance
ogrisel Sep 4, 2025
@@ -0,0 +1,7 @@
- Forest estimators such as :class:`ensemble.RandomForestClassifier` and
:class:`ensemble.ExtraTreesRegressor` now have a new attribute
for unbiased impurity feature importance: `unbiased_feature_importances_`.
This attribute leverages out-of-bag samples to correct the known bias of MDI
importance. It is automatically computed during training when
`oob_score=True`.
By :user:`Gaétan de Castellane <GaetandeCast>`.
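A minimal usage sketch of the new attribute (the dataset and hyperparameters below are illustrative assumptions; only the attribute name and the `oob_score=True` requirement come from this entry):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Illustrative toy dataset; any classification data works here.
X, y = make_classification(n_samples=500, n_features=8, random_state=0)
# oob_score=True is required so that out-of-bag samples are available at fit time.
forest = RandomForestClassifier(oob_score=True, random_state=0).fit(X, y)
# One corrected importance value per feature, computed automatically during fit.
print(forest.unbiased_feature_importances_)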
@@ -0,0 +1,6 @@
- :class:`tree.Tree` now has a method that allows passing test samples
to compute a test score and a feature importance measure.
The public method `compute_unbiased_feature_importance` is used in forest
estimators to compute out-of-bag predictions and unbiased feature importance
measures, but it can be used in any tree ensemble.
By :user:`Gaétan de Castellane <GaetandeCast>`.
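A hedged sketch of calling this tree-level method directly, matching the signature used in the example diff below (`importance, prediction = clf.compute_unbiased_feature_importance(X_test, y_test)`); the dataset and split are illustrative:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# Held-out samples are used to correct the bias of the train-set MDI; the method
# also returns the tree's predictions on those samples.
importance, prediction = clf.compute_unbiased_feature_importance(X_test, y_test)
print(importance)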
179 changes: 99 additions & 80 deletions examples/inspection/plot_permutation_importance.py
@@ -3,25 +3,32 @@
Permutation Importance vs Random Forest Feature Importance (MDI)
================================================================

In this example, we will compare the impurity-based feature importance of
:class:`~sklearn.ensemble.RandomForestClassifier` with the
permutation importance on the titanic dataset using
:func:`~sklearn.inspection.permutation_importance`. We will show that the
impurity-based feature importance can inflate the importance of numerical
features.

Furthermore, the impurity-based feature importance of random forests suffers
from being computed on statistics derived from the training dataset: the
importances can be high even for features that are not predictive of the target
variable, as long as the model has the capacity to use them to overfit.

This example shows how to use Permutation Importances as an alternative that
can mitigate those limitations.
In this example, we will show on the Titanic dataset how the impurity-based feature
importance (MDI, introduced by Breiman in [RF2001]_) of
:class:`~sklearn.ensemble.RandomForestClassifier` can give misleading results by
favoring high-cardinality features, and we will present two alternatives that avoid
this issue.

In a nutshell, the impurity-based feature importance of random forests suffers from
being computed on statistics derived from the training dataset: the importances can be
high even for features that are not predictive of the target variable, as long as the
model has the capacity to use them to overfit. The effect is stronger the more unique
values the feature takes.

A first solution is to use :func:`~sklearn.inspection.permutation_importance` on test
data instead. Although this method is slower, it is not restricted to random forests and
does not suffer from the bias of MDI.

Another solution is to use the `unbiased_feature_importances_` attribute of random
forests, which leverages out-of-bag samples to correct the aforementioned bias. This
method was introduced by Zhou and Hooker in [UFI2020]_ and uses the samples that were not
used
in the construction of each tree of the forest to modify the MDI.

.. rubric:: References

* :doi:`L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
2001. <10.1023/A:1010933404324>`
.. [RF2001] :doi:`"Random Forests" <10.1023/A:1010933404324>` L. Breiman, 2001
.. [UFI2020] :doi:`"Unbiased Measurement of Feature Importance in Tree-Based Methods"
<10.1145/3429445>` Zhengze Zhou, Giles Hooker, 2020

"""

@@ -87,7 +94,7 @@
rf = Pipeline(
[
("preprocess", preprocessing),
("classifier", RandomForestClassifier(random_state=42)),
("classifier", RandomForestClassifier(random_state=42, oob_score=True)),
]
)
rf.fit(X_train, y_train)
@@ -98,9 +105,16 @@
# Before inspecting the feature importances, it is important to check that
# the model predictive performance is high enough. Indeed, there would be little
# interest in inspecting the important features of a non-predictive model.
#
# By default, random forests train each tree on a bootstrap sample of the dataset, a
# procedure known as bagging, leaving aside "out-of-bag" (oob) samples.
# These samples can be leveraged to compute an accuracy score independently of the
# training samples, by setting the parameter `oob_score=True`.
# This score should be close to the test score.

print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}")
print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}")
print(f"RF out-of-bag accuracy: {rf[-1].oob_score_:.3f}")

# %%
# Here, one can observe that the train accuracy is very high (the forest model
@@ -151,96 +165,101 @@
# %%
ax = mdi_importances.plot.barh()
ax.set_title("Random Forest Feature Importances (MDI)")
ax.set_xlabel("Decrease in impurity")
ax.figure.tight_layout()

# %%
# As an alternative, the permutation importances of ``rf`` are computed on a
# held out test set. This shows that the low cardinality categorical feature,
# `sex` and `pclass` are the most important features. Indeed, permuting the
# values of these features will lead to the most decrease in accuracy score of the
# model on the test set.
#
# Also, note that both random features have very low importances (close to 0) as
# expected.
# To avoid this issue, we can compute permutation importance instead. However, doing so
# on the training data gives misleading results: as shown below, it inflates the
# importance of every feature, even the random ones. Permutation importance should
# therefore be computed on held-out test data.
import matplotlib.pyplot as plt

from sklearn.inspection import permutation_importance

result = permutation_importance(
result_train = permutation_importance(
rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
)
result_test = permutation_importance(
rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)

sorted_importances_idx = result.importances_mean.argsort()
importances = pd.DataFrame(
result.importances[sorted_importances_idx].T,
sorted_importances_idx = result_test.importances_mean.argsort()
importances_train = pd.DataFrame(
result_train.importances[sorted_importances_idx].T,
columns=X.columns[sorted_importances_idx],
)
ax = importances.plot.box(vert=False, whis=10)
ax.set_title("Permutation Importances (test set)")
ax.axvline(x=0, color="k", linestyle="--")
ax.set_xlabel("Decrease in accuracy score")
ax.figure.tight_layout()

# %%
# It is also possible to compute the permutation importances on the training
# set. This reveals that `random_num` and `random_cat` get a significantly
# higher importance ranking than when computed on the test set. The difference
# between those two plots is a confirmation that the RF model has enough
# capacity to use that random numerical and categorical features to overfit.
result = permutation_importance(
rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
)

sorted_importances_idx = result.importances_mean.argsort()
importances = pd.DataFrame(
result.importances[sorted_importances_idx].T,
importances_test = pd.DataFrame(
result_test.importances[sorted_importances_idx].T,
columns=X.columns[sorted_importances_idx],
)
ax = importances.plot.box(vert=False, whis=10)
ax.set_title("Permutation Importances (train set)")
ax.axvline(x=0, color="k", linestyle="--")
ax.set_xlabel("Decrease in accuracy score")
ax.figure.tight_layout()

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
importances_train.plot.box(vert=False, whis=10, ax=ax[0])
ax[0].set_title("Permutation Importances (train set)")
ax[0].axvline(x=0, color="k", linestyle="--")
ax[0].set_xlabel("Decrease in accuracy score")

importances_test.plot.box(vert=False, whis=10, ax=ax[1])
ax[1].set_title("Permutation Importances (test set)")
ax[1].axvline(x=0, color="k", linestyle="--")
ax[1].set_xlabel("Decrease in accuracy score")
fig.tight_layout()
# %%
# We can further retry the experiment by limiting the capacity of the trees
# to overfit by setting `min_samples_leaf` at 20 data points.
# To see how this problem relates to overfitting, we can set `min_samples_leaf` to 20
# data points to reduce the overfitting of the model.
rf.set_params(classifier__min_samples_leaf=20).fit(X_train, y_train)

# %%
# Observing the accuracy score on the training and testing set, we observe that
# Looking at the accuracy score on the training and testing set, we observe that
# the two metrics are very similar now. Therefore, our model is not overfitting
# anymore. We can then check the permutation importances with this new model.
# anymore.
print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}")
print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}")

# %%
train_result = permutation_importance(
# We can see that our model is now much less reliant on uninformative features and
# therefore assigns them lower importance. However, we still get non-zero importance
# values for the completely random features when using training data only.
mdi_importances = pd.Series(
rf[-1].feature_importances_, index=feature_names
).sort_values(ascending=True)

result_train = permutation_importance(
rf, X_train, y_train, n_repeats=10, random_state=42, n_jobs=2
)
test_results = permutation_importance(
rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=2
)
sorted_importances_idx = train_result.importances_mean.argsort()

# %%
train_importances = pd.DataFrame(
train_result.importances[sorted_importances_idx].T,
columns=X.columns[sorted_importances_idx],
)
test_importances = pd.DataFrame(
test_results.importances[sorted_importances_idx].T,
sorted_importances_idx = result_train.importances_mean.argsort()
importances_train = pd.DataFrame(
result_train.importances[sorted_importances_idx].T,
columns=X.columns[sorted_importances_idx],
)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
ax[0] = mdi_importances.plot.barh(ax=ax[0])
ax[0].set_xlabel("Decrease in impurity")
ax[0].set_title("Random Forest Feature Importances (MDI)")
importances_train.plot.box(vert=False, whis=10, ax=ax[1])
ax[1].set_title("Permutation Importances (train set)")
ax[1].axvline(x=0, color="k", linestyle="--")
ax[1].set_xlabel("Decrease in accuracy score")
fig.tight_layout()

# %%
for name, importances in zip(["train", "test"], [train_importances, test_importances]):
ax = importances.plot.box(vert=False, whis=10)
ax.set_title(f"Permutation Importances ({name} set)")
ax.set_xlabel("Decrease in accuracy score")
ax.axvline(x=0, color="k", linestyle="--")
ax.figure.tight_layout()
# To completely ignore irrelevant features, we should compute the permutation importance
# of ``rf`` on a held out test set.
# However, when test samples are not available, or when permutation importance becomes
# too expensive to compute, there exists a modified version of the MDI,
# `unbiased_feature_importances_`, available as soon as `oob_score` is set to `True`,
# that uses the out-of-bag samples of the trees to solve the bias problem.
ufi = rf[-1].unbiased_feature_importances_
mdi_importances = pd.Series(ufi, index=feature_names).sort_values(ascending=True)

# %%
# Now, we can observe that on both sets, the `random_num` and `random_cat`
# features have a lower importance compared to the overfitting random forest.
# However, the conclusions regarding the importance of the other features are
# still valid.
ax = mdi_importances.plot.barh()
ax.set_title("Unbiased Feature Importances (UFI)")
ax.axvline(x=0, color="k", linestyle="--")
ax.set_xlabel("Decrease in impurity")
ax.figure.tight_layout()
# %%
# We can see that the random features have an importance of zero and the important
# features are ranked in the same way as with permutation importance. This method
# is much faster than permutation importance but is limited to random forests.
23 changes: 23 additions & 0 deletions examples/tree/plot_unveil_tree_structure.py
@@ -13,6 +13,7 @@
- the leaf that was reached by a sample using the apply method;
- the rules that were used to predict a sample;
- the decision path shared by a group of samples;
- the importance of features computed on a test set.

"""

@@ -24,6 +25,7 @@

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

@@ -235,3 +237,24 @@
)
)
print("This is {prop}% of all nodes.".format(prop=100 * len(common_node_id) / n_nodes))

##############################################################################
# Feature importance on a test set
# --------------------------------
# The `compute_unbiased_feature_importance` method allows us to evaluate the relative
# importance of each feature of the dataset in an unbiased way. By looking at how much
# purity was created in the children of a node, we can assess the ability of the feature
# used in the node to explain the data. Doing so purely on the data the tree was trained
# on might not generalize well, which is why the method
# `compute_unbiased_feature_importance` takes test samples as input to correct this
# effect. Additionally, the method returns the predictions of the tree on these test
# samples, which allows us to assess the performance of the model.

importance, prediction = clf.compute_unbiased_feature_importance(X_test, y_test)
for i in range(len(importance)):
    print(f"Importance of feature {i}: {importance[i]:.3f}")
print(
    "Predicted probabilities of classes for the first test sample:\n", prediction[0]
)
score = accuracy_score(np.argmax(prediction, axis=1), y_test)
print(f"Score of the predictions on the test set: {score:.3f}")