[MRG] Adds Permutation Importance #13146
@@ -0,0 +1,69 @@
.. _permutation_importance:

Permutation feature importance
==============================

.. currentmodule:: sklearn.inspection

Permutation feature importance is a model inspection technique that can be used
for any fitted estimator when the data is tabular. This is especially useful
for non-linear or opaque estimators. The permutation feature importance is
defined to be the decrease in a model score when a single feature value is
randomly shuffled [1]_. This procedure breaks the relationship between the
feature and the target, so the drop in the model score is indicative of how
much the model depends on the feature. This technique benefits from being model
agnostic and can be computed many times with different permutations of the
feature.

The :func:`permutation_importance` function computes the feature importance of
estimators for a given dataset. The ``n_repeats`` parameter sets the number of
times a feature is randomly shuffled and returns a sample of feature
importances. Permutation importances can be computed either on the training set
or on a held-out testing or validation set. Using a held-out set makes it
possible to highlight which features contribute the most to the generalization
power of the inspected model. Features that are important on the training set
but not on the held-out set might cause the model to overfit.
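
A minimal sketch of the intended usage, modeled on the examples in this pull
request (the synthetic dataset and the parameter values below are purely
illustrative)::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance
    from sklearn.model_selection import train_test_split

    # A synthetic dataset in which only a few features carry signal.
    X, y = make_classification(n_samples=1000, n_features=10,
                               n_informative=3, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    model = RandomForestClassifier(random_state=0).fit(X_train, y_train)

    # Shuffle each feature n_repeats times on the held-out set.
    result = permutation_importance(model, X_test, y_test, n_repeats=10,
                                    random_state=0)
    print(result.importances_mean)   # mean drop in score per feature
    print(result.importances.shape)  # raw samples, one row per feature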

Note that features that are deemed non-important for a model with low
predictive performance could be highly predictive for a model that generalizes
better. The conclusions should always be drawn in the context of the specific
model under inspection and cannot be automatically generalized to the intrinsic
predictive value of the features by themselves. It is therefore important to
evaluate the predictive power of a model using a held-out set (or better, with
cross-validation) prior to computing importances.
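
For example, a quick sanity check of that predictive power might look like the
following sketch (the estimator and the dataset are placeholders)::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_score

    X, y = make_classification(random_state=0)
    model = RandomForestClassifier(random_state=0)

    # If cross-validated scores are close to chance level, the permutation
    # importances of this model are not worth interpreting.
    scores = cross_val_score(model, X, y, cv=5)
    print(scores.mean(), scores.std())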

Relation to impurity-based importance in trees
----------------------------------------------

Tree-based models provide an alternative measure of feature importance based on
the mean decrease in impurity (MDI, the splitting criterion). This can give
importance to features that may not be predictive on unseen data. Permutation
feature importance avoids this issue, since it can be computed on unseen data.
Furthermore, impurity-based feature importances for trees are strongly biased
and favor high cardinality features (typically numerical features).
Permutation-based feature importances do not exhibit such a bias. Additionally,
the permutation feature importance may be computed with an arbitrary metric on
the tree's predictions. These two methods of obtaining feature importance are
explored in:
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`.
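
The two measures can be inspected side by side. A brief sketch, assuming a
fitted forest ``model`` and held-out data ``X_test``, ``y_test`` (e.g. as in
the sketch above), and assuming a ``scoring`` parameter as exposed by
:func:`permutation_importance` in the released API::

    # Impurity-based importances are derived from training-set statistics only.
    mdi_importances = model.feature_importances_

    # Permutation importances can use any metric; here ROC AUC instead of the
    # default accuracy (the scoring argument is an assumption, see above).
    result = permutation_importance(model, X_test, y_test, scoring='roc_auc',
                                    n_repeats=10, random_state=0)
    print(mdi_importances)
    print(result.importances_mean)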

Strongly correlated features
----------------------------

When two features are correlated and one of them is permuted, the model still
has access to the same information through its correlated feature. This results
in a lower reported importance for both features, even though they might
*actually* be important. One way to handle this is to cluster features that are
correlated and only keep one feature from each cluster. This use case is
explored in:
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py`.

.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`
  * :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py`

.. topic:: References:

   .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
      2001. https://doi.org/10.1023/A:1010933404324
@@ -0,0 +1,177 @@
""" | ||
================================================================
Permutation Importance vs Random Forest Feature Importance (MDI)
================================================================

In this example, we will compare the impurity-based feature importance of
:class:`~sklearn.ensemble.RandomForestClassifier` with the permutation
importance on the Titanic dataset using
:func:`~sklearn.inspection.permutation_importance`. We will show that the
impurity-based feature importance can inflate the importance of numerical
features.

Furthermore, the impurity-based feature importance of random forests suffers
from being computed on statistics derived from the training dataset: the
importances can be high even for features that are not predictive of the
target variable, as long as the model has the capacity to use them to overfit.

This example shows how to use permutation importance as an alternative that
can mitigate those limitations.

.. topic:: References:

   .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
      2001. https://doi.org/10.1023/A:1010933404324
"""
print(__doc__)
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


##############################################################################
# Data Loading and Feature Engineering
# ------------------------------------
# Let's use pandas to load a copy of the Titanic dataset. The following shows
# how to apply separate preprocessing on numerical and categorical features.
#
# We further include two random variables that are not correlated in any way
# with the target variable (``survived``):
#
# - ``random_num`` is a high cardinality numerical variable (as many unique
#   values as records).
# - ``random_cat`` is a low cardinality categorical variable (3 possible
#   values).
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)

# Seed the random number generator so the added noise features are
# reproducible across runs.
rng = np.random.RandomState(seed=42)
X['random_cat'] = rng.randint(3, size=X.shape[0])
X['random_num'] = rng.randn(X.shape[0])

categorical_columns = ['pclass', 'sex', 'embarked', 'random_cat']
numerical_columns = ['age', 'sibsp', 'parch', 'fare', 'random_num']

X = X[categorical_columns + numerical_columns]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=42)

categorical_pipe = Pipeline([
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('onehot', OneHotEncoder(handle_unknown='ignore'))
])
numerical_pipe = Pipeline([
    ('imputer', SimpleImputer(strategy='mean'))
])

preprocessing = ColumnTransformer(
    [('cat', categorical_pipe, categorical_columns),
     ('num', numerical_pipe, numerical_columns)])

rf = Pipeline([
    ('preprocess', preprocessing),
    ('classifier', RandomForestClassifier(random_state=42))
])
rf.fit(X_train, y_train)

##############################################################################
# Accuracy of the Model
# ---------------------
# Prior to inspecting the feature importances, it is important to check that
# the model's predictive performance is high enough. Indeed, there would be
# little interest in inspecting the important features of a non-predictive
# model.
#
# Here one can observe that the train accuracy is very high (the forest model
# has enough capacity to completely memorize the training set) but it can
# still generalize well enough to the test set thanks to the built-in bagging
# of random forests.
#
# It might be possible to trade some accuracy on the training set for a
# slightly better accuracy on the test set by limiting the capacity of the
# trees (for instance by setting ``min_samples_leaf=5`` or
# ``min_samples_leaf=10``) so as to limit overfitting while not introducing
# too much underfitting.
#
# However, let's keep our high capacity random forest model for now so as to
# illustrate some pitfalls of feature importance on variables with many
# unique values.
print("RF train accuracy: %0.3f" % rf.score(X_train, y_train))
print("RF test accuracy: %0.3f" % rf.score(X_test, y_test))


##############################################################################
# Tree's Feature Importance from Mean Decrease in Impurity (MDI)
# --------------------------------------------------------------
# The impurity-based feature importance ranks the numerical features as the
# most important ones. As a result, the non-predictive ``random_num`` variable
# is ranked as the most important!
#
# This problem stems from two limitations of impurity-based feature
# importances:
#
# - impurity-based importances are biased towards high cardinality features;
# - impurity-based importances are computed on training set statistics and
#   therefore do not reflect the ability of a feature to be useful for making
#   predictions that generalize to the test set (when the model has enough
#   capacity).
ohe = (rf.named_steps['preprocess']
       .named_transformers_['cat']
       .named_steps['onehot'])
feature_names = ohe.get_feature_names(input_features=categorical_columns)
feature_names = np.r_[feature_names, numerical_columns]

tree_feature_importances = (
    rf.named_steps['classifier'].feature_importances_)
sorted_idx = tree_feature_importances.argsort()

y_ticks = np.arange(0, len(feature_names))
fig, ax = plt.subplots()
ax.barh(y_ticks, tree_feature_importances[sorted_idx])
ax.set_yticks(y_ticks)
ax.set_yticklabels(feature_names[sorted_idx])
ax.set_title("Random Forest Feature Importances (MDI)")
fig.tight_layout()
plt.show()


##############################################################################
# As an alternative, the permutation importances of ``rf`` are computed on a
# held-out test set. This shows that the low cardinality categorical feature
# ``sex`` is the most important one.
#
# Also note that both random features have very low importances (close to 0),
# as expected.
result = permutation_importance(rf, X_test, y_test, n_repeats=10,
                                random_state=42, n_jobs=2)
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
           vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)") | ||
fig.tight_layout() | ||
plt.show() | ||
|
||
############################################################################## | ||
# It is also possible to compute the permutation importances on the training | ||
# set. This reveals that ``random_num`` gets a significantly higher importance | ||
# ranking than when computed on the test set. The difference between those two | ||
# plots is a confirmation that the RF model has enough capacity to use that | ||
# random numerical feature to overfit. You can further confirm this by | ||
# re-running this example with constrained RF with min_samples_leaf=10. | ||
result = permutation_importance(rf, X_train, y_train, n_repeats=10, | ||
random_state=42, n_jobs=2) | ||
sorted_idx = result.importances_mean.argsort() | ||
|
||
fig, ax = plt.subplots() | ||
ax.boxplot(result.importances[sorted_idx].T, | ||
vert=False, labels=X_train.columns[sorted_idx]) | ||
ax.set_title("Permutation Importances (train set)") | ||
fig.tight_layout() | ||
plt.show() |
@@ -0,0 +1,111 @@
""" | ||
================================================================= | ||
Permutation Importance with Multicollinear or Correlated Features | ||
================================================================= | ||
|
||
In this example, we compute the permutation importance on the Wisconsin | ||
breast cancer dataset using :func:`~sklearn.inspection.permutation_importance`. | ||
The :class:`~sklearn.ensemble.RandomForestClassifier` can easily get about 97% | ||
accuracy on a test dataset. Because this dataset contains multicollinear | ||
features, the permutation importance will show that none of the features are | ||
important. One approach to handling multicollinearity is by performing | ||
hierarchical clustering on the features' Spearman rank-order correlations, | ||
picking a threshold, and keeping a single feature from each cluster. | ||
|
||
.. note:: | ||
See also | ||
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py` | ||
""" | ||
print(__doc__) | ||
from collections import defaultdict | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from scipy.stats import spearmanr | ||
from scipy.cluster import hierarchy | ||
|
||
from sklearn.datasets import load_breast_cancer | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.inspection import permutation_importance | ||
from sklearn.model_selection import train_test_split | ||
|
||
############################################################################## | ||
# Random Forest Feature Importance on Breast Cancer Data | ||
# ------------------------------------------------------ | ||
# First, we train a random forest on the breast cancer dataset and evaluate | ||
# its accuracy on a test set: | ||
data = load_breast_cancer() | ||
X, y = data.data, data.target | ||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) | ||
|
||
clf = RandomForestClassifier(n_estimators=100, random_state=42) | ||
clf.fit(X_train, y_train) | ||
print("Accuracy on test data: {:.2f}".format(clf.score(X_test, y_test))) | ||
|
||
############################################################################## | ||
# Next, we plot the tree based feature importance and the permutation | ||
# importance. The permutation importance plot shows that permuting a feature | ||
# drops the accuracy by at most `0.012`, which would suggest that none of the | ||
# features are important. This is in contradiction with the high test accuracy | ||
# computed above: some feature must be important. The permutation importance | ||
# is calculated on the training set to show how much the model relies on each | ||
# feature during training. | ||
result = permutation_importance(clf, X_train, y_train, n_repeats=10, | ||
random_state=42) | ||
perm_sorted_idx = result.importances_mean.argsort() | ||
|
||
tree_importance_sorted_idx = np.argsort(clf.feature_importances_) | ||
tree_indicies = np.arange(1, len(clf.feature_importances_) + 1) | ||
|
||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) | ||
ax1.barh(tree_indicies, clf.feature_importances_[tree_importance_sorted_idx]) | ||
ax1.set_yticklabels(data.feature_names) | ||
ax1.set_yticks(tree_indicies) | ||
ax2.boxplot(result.importances[perm_sorted_idx].T, vert=False, | ||
labels=data.feature_names) | ||
fig.tight_layout() | ||
plt.show() | ||

##############################################################################
# Handling Multicollinear Features
# --------------------------------
# When features are collinear, permuting one feature has little effect on the
# model's performance because it can get the same information from a
# correlated feature. One way to handle multicollinear features is to perform
# hierarchical clustering on the Spearman rank-order correlations, pick a
# threshold, and keep a single feature from each cluster. First, we plot a
# heatmap of the correlated features:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
corr = spearmanr(X).correlation
[Review thread on the line above]
- Right now there is no example doing just this, right? Might be worth adding
  this to the feature selection section of the user guide?
- Or we could add a feature selection method that actually implements this ;)
  I'm pretty sure there's an issue on that. Also would be a nice plotting
  function...
- Adding a feature selection method based on correlation is on my todo list.
  A data exploring plotting function? Hmm maybe.
- This is kind of done manually in this example. (Not easy to place into a
  pipeline)
- sounds like a job for dabl, not sklearn ;)
corr_linkage = hierarchy.ward(corr)
dendro = hierarchy.dendrogram(corr_linkage, labels=data.feature_names, ax=ax1,
                              leaf_rotation=90)
dendro_idx = np.arange(0, len(dendro['ivl']))

ax2.imshow(corr[dendro['leaves'], :][:, dendro['leaves']])
ax2.set_xticks(dendro_idx)
ax2.set_yticks(dendro_idx)
ax2.set_xticklabels(dendro['ivl'], rotation='vertical')
ax2.set_yticklabels(dendro['ivl'])
fig.tight_layout()
plt.show()

##############################################################################
# Next, we manually pick a threshold by visual inspection of the dendrogram,
# group the features into clusters, keep a single feature from each cluster,
# select those features from our dataset, and train a new random forest.
# The test accuracy of the new random forest does not change much compared to
# the random forest trained on the complete dataset.
cluster_ids = hierarchy.fcluster(corr_linkage, 1, criterion='distance')
[Review thread on the line above]
- How did you pick the threshold?
- Visually
- +1 for making this choice more explicit in the comment.
cluster_id_to_feature_ids = defaultdict(list)
for idx, cluster_id in enumerate(cluster_ids):
    cluster_id_to_feature_ids[cluster_id].append(idx)
selected_features = [v[0] for v in cluster_id_to_feature_ids.values()]

X_train_sel = X_train[:, selected_features]
X_test_sel = X_test[:, selected_features]

clf_sel = RandomForestClassifier(n_estimators=100, random_state=42)
clf_sel.fit(X_train_sel, y_train)
print("Accuracy on test data with features removed: {:.2f}".format(
    clf_sel.score(X_test_sel, y_test)))
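

##############################################################################
# As a hedged follow-up (not part of the original example), one could also
# check that the permutation importances of the reduced model are now more
# informative, since each kept feature no longer has a correlated substitute
# the forest could fall back on. The variable names below are illustrative.
result_sel = permutation_importance(clf_sel, X_test_sel, y_test, n_repeats=10,
                                    random_state=42)
sel_sorted_idx = result_sel.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(result_sel.importances[sel_sorted_idx].T, vert=False,
           labels=data.feature_names[selected_features][sel_sorted_idx])
ax.set_title("Permutation Importances (selected features, test set)")
fig.tight_layout()
plt.show()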