Merged
98 commits
00c56de
ENH Adds files
thomasjpfan Feb 8, 2019
29be4f4
ENH Adds permutation importance
thomasjpfan Feb 12, 2019
2e09bfb
RFC Better names
thomasjpfan Feb 12, 2019
f7bb490
STY Flake8
thomasjpfan Feb 12, 2019
6f0175c
ENH: Adds inspect module
thomasjpfan Feb 12, 2019
bf44eb1
DOC Adds pre_dispatch
thomasjpfan Feb 12, 2019
85ed781
DOC Adds permutation importance example
thomasjpfan Feb 12, 2019
66e71dd
Trigger CI
thomasjpfan Feb 13, 2019
a93a9f3
BLD Adds inspect to configuration
thomasjpfan Feb 13, 2019
ee1e77f
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Feb 19, 2019
0670997
RFC Update to only inspect fitted model
thomasjpfan Feb 19, 2019
334c8c3
RFC Removes parameters
thomasjpfan Feb 19, 2019
354ac62
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Mar 1, 2019
260fa54
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Mar 14, 2019
92894a1
ENH: Adds pandas support
thomasjpfan Mar 14, 2019
f45c653
STY Flake8
thomasjpfan Mar 14, 2019
50d8550
DOC Adds new permutation importance example
thomasjpfan Mar 15, 2019
74e915f
ENH Renames module to model_inspection
thomasjpfan Mar 15, 2019
2a7d8e2
DOC Fix links
thomasjpfan Mar 15, 2019
920362a
DOC Fixes image link
thomasjpfan Mar 15, 2019
747599b
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Mar 15, 2019
a48e151
DOC Fixes image link
thomasjpfan Mar 15, 2019
51b745d
DOC Spelling
thomasjpfan Mar 16, 2019
23c8d11
DOC
thomasjpfan Mar 17, 2019
4241414
TST Fix keyword
thomasjpfan Mar 17, 2019
a12bc0c
Rework RF Imp vs Perm Imp example (#4)
ogrisel Apr 1, 2019
e864071
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Apr 1, 2019
9a57e20
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 9, 2019
5798338
WIP
thomasjpfan May 9, 2019
72b9003
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 9, 2019
37d52ba
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 10, 2019
fe675f9
WIP
thomasjpfan May 10, 2019
ced888d
WIP
thomasjpfan May 13, 2019
b0357fc
DOC Adds multcollinear features example
thomasjpfan May 15, 2019
91bf4e2
WIP
thomasjpfan May 15, 2019
a1d5880
DOC: Clean up docs
thomasjpfan May 15, 2019
4eb1e82
TST Adds tests for strings
thomasjpfan May 15, 2019
6f98f11
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 15, 2019
1656985
STY Indent correction
thomasjpfan May 16, 2019
0d34d80
WIP
thomasjpfan May 18, 2019
62868f6
ENH Uses check_X_y
thomasjpfan May 18, 2019
e7efe6d
TST Adds test with strings
thomasjpfan May 18, 2019
d75b557
STY Fix
thomasjpfan May 22, 2019
e3bbcda
TST Adds column transformer to test
thomasjpfan May 22, 2019
6c60e43
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan May 28, 2019
ed469d6
CLN Address comments
thomasjpfan May 29, 2019
6180975
CLN Removes import
thomasjpfan May 29, 2019
24d740e
TST Adds test with nan
thomasjpfan May 29, 2019
914335d
CLN Removes import
thomasjpfan May 29, 2019
f0beac6
ENH Parallel
thomasjpfan May 29, 2019
ac8d5a3
DOC comments
thomasjpfan May 29, 2019
31e9408
ENH Better handling of pandas
thomasjpfan May 29, 2019
be3f65b
ENH Clear checking of pandas dataframe
thomasjpfan May 29, 2019
e1df6a6
STY Formatting
thomasjpfan May 29, 2019
78aba62
ENH Copies in parallel helper
thomasjpfan May 29, 2019
d6ca3c5
DOC Adds comments
thomasjpfan May 29, 2019
a2aa960
BUG Fix copying
thomasjpfan May 29, 2019
9ff6aa1
BUG Fix for pandas
thomasjpfan May 29, 2019
f112cd3
BUG Fix for pandas
thomasjpfan May 29, 2019
884d648
REV
thomasjpfan May 29, 2019
c64e6a1
BLD Trigger CI
thomasjpfan May 29, 2019
d2fad37
BUG Fix
thomasjpfan May 30, 2019
50b6b98
BUG Fix
thomasjpfan May 30, 2019
14b3efd
TST Does this work
thomasjpfan May 30, 2019
f41f5b3
BUG Fixes test
thomasjpfan May 30, 2019
3cd43ce
BUG Fixes test
thomasjpfan May 30, 2019
318c961
BUG Fix
thomasjpfan May 30, 2019
aa6c79d
BUG Fix
thomasjpfan May 30, 2019
5292136
BUG Fix
thomasjpfan May 30, 2019
9b53e35
Merge branch 'permutation_importance_v2' into permutation_importance
thomasjpfan May 30, 2019
bc3ea96
STY Fix
thomasjpfan May 30, 2019
7d79a49
TST Fix
thomasjpfan May 30, 2019
b487618
TST Fix segfault
thomasjpfan May 31, 2019
7a83608
CLN Address comments
thomasjpfan Jun 17, 2019
af9c961
CLN Address comments
thomasjpfan Jun 17, 2019
664d863
ENH Returns a bunch
thomasjpfan Jun 17, 2019
8a022c6
STY Flake8
thomasjpfan Jun 17, 2019
78ed4e8
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jun 17, 2019
fbebc5e
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jun 24, 2019
d62df83
CLN Renames bunch key
thomasjpfan Jun 25, 2019
118601a
DOC Updates api
thomasjpfan Jun 25, 2019
9f1325f
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jun 25, 2019
2655f82
DOC Updates api
thomasjpfan Jun 25, 2019
ca9a78b
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 1, 2019
1748227
TST Adds permutation test with linear_regression
thomasjpfan Jul 2, 2019
e1607ff
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 2, 2019
fb4f926
DOC update
thomasjpfan Jul 4, 2019
eb154a9
DOC Fix label cutoff
thomasjpfan Jul 4, 2019
b1f9c70
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 4, 2019
946ca59
CLN Address comments
thomasjpfan Jul 9, 2019
5676930
TST Adds test for random_state effect
thomasjpfan Jul 9, 2019
78cefef
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 9, 2019
204c3ab
DOC Adds permutation importance
thomasjpfan Jul 9, 2019
dab6801
DOC Adds ogrisel suggestion
thomasjpfan Jul 9, 2019
c67667f
Merge remote-tracking branch 'upstream/master' into permutation_impor…
thomasjpfan Jul 16, 2019
f90eacf
DOC Address guillaumes comments
thomasjpfan Jul 16, 2019
6b428d7
DOC Address andreas comments
thomasjpfan Jul 16, 2019
94c4c56
DOC Update
thomasjpfan Jul 16, 2019
13 changes: 13 additions & 0 deletions doc/inspection.rst
@@ -5,6 +5,19 @@
Inspection
----------

Predictive performance is often the main goal of developing machine learning
models. Yet summarising performance with an evaluation metric is often
insufficient: it assumes that the evaluation metric and test dataset
perfectly reflect the target domain, which is rarely true. In some domains,
a model needs a certain level of interpretability before it can be deployed,
and a model that exhibits performance issues needs to be debugged so that its
underlying problem can be understood. The
:mod:`sklearn.inspection` module provides tools to help understand the
predictions from a model and what affects them. This can be used to
evaluate assumptions and biases of a model, design a better model, or
to diagnose issues with model performance.

.. toctree::

modules/partial_dependence
modules/permutation_importance
2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -657,6 +657,7 @@ Kernels:
:template: function.rst

inspection.partial_dependence
inspection.permutation_importance
inspection.plot_partial_dependence


@@ -1257,7 +1258,6 @@ Model validation
pipeline.make_pipeline
pipeline.make_union


.. _preprocessing_ref:

:mod:`sklearn.preprocessing`: Preprocessing and Normalization
69 changes: 69 additions & 0 deletions doc/modules/permutation_importance.rst
@@ -0,0 +1,69 @@

.. _permutation_importance:

Permutation feature importance
==============================

.. currentmodule:: sklearn.inspection

Permutation feature importance is a model inspection technique that can be used
for any fitted estimator when the data is rectangular (tabular). It is especially
useful for non-linear or opaque estimators. The permutation feature
importance is defined to be the decrease in a model score when a single feature
value is randomly shuffled [1]_. This procedure breaks the relationship between
the feature and the target, so the drop in the model score indicates how much
the model depends on the feature. The technique is model agnostic and can be
computed many times with different permutations of the feature.
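
For illustration, the quantity being estimated can be sketched by hand. The
snippet below is purely illustrative (the estimator and dataset are arbitrary
choices, not part of any API)::

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split

    X, y = load_iris(return_X_y=True)
    X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
    model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

    rng = np.random.RandomState(42)
    baseline = model.score(X_val, y_val)
    for feature_idx in range(X_val.shape[1]):
        # Shuffle a single column and measure the resulting drop in score.
        X_permuted = X_val.copy()
        X_permuted[:, feature_idx] = rng.permutation(X_permuted[:, feature_idx])
        print(feature_idx, baseline - model.score(X_permuted, y_val))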

The :func:`permutation_importance` function computes the feature importance
of estimators for a given dataset. The ``n_repeats`` parameter sets the number
of times a feature is randomly shuffled, so the function returns a sample of
feature importances for each feature. Permutation importances can be computed
either on the training set or on a held-out testing or validation set. Using a
held-out set makes it possible to highlight which features contribute the most
to the generalization power of the inspected model. Features that are important
on the training set but not on the held-out set might cause the model to overfit.
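
A minimal usage sketch on a held-out validation set (the estimator, dataset and
``n_repeats`` value below are illustrative choices)::

    from sklearn.datasets import load_diabetes
    from sklearn.inspection import permutation_importance
    from sklearn.linear_model import Ridge
    from sklearn.model_selection import train_test_split

    X, y = load_diabetes(return_X_y=True)
    X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
    model = Ridge(alpha=1e-2).fit(X_train, y_train)

    result = permutation_importance(model, X_val, y_val, n_repeats=30,
                                    random_state=0)
    # result.importances has shape (n_features, n_repeats);
    # importances_mean and importances_std summarise the repetitions.
    print(result.importances_mean)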

Note that features that are deemed non-important for a model with
low predictive performance could be highly predictive for a model that
generalizes better. Conclusions should always be drawn in the context of
the specific model under inspection and cannot be automatically generalized to
the intrinsic predictive value of the features by themselves. Therefore it is
always important to evaluate the predictive power of a model using a held-out
set (or better, with cross-validation) prior to computing importances.
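
For example, a quick cross-validation check can be run before inspecting
importances (again an illustrative sketch with an arbitrary estimator and
dataset)::

    from sklearn.datasets import load_diabetes
    from sklearn.linear_model import Ridge
    from sklearn.model_selection import cross_val_score

    X, y = load_diabetes(return_X_y=True)
    scores = cross_val_score(Ridge(alpha=1e-2), X, y, cv=5)
    # Only if this score is deemed acceptable does it make sense to inspect
    # the permutation importances of a model refitted on the training data.
    print("R^2: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std()))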

Relation to impurity-based importance in trees
----------------------------------------------

Tree-based models provide an alternative measure of feature importance based
on the mean decrease in impurity (MDI, the splitting criterion). This measure
can give importance to features that may not be predictive on unseen data. The
permutation feature importance avoids this issue, since it can be applied to
unseen data. Furthermore, impurity-based feature importances for trees
are strongly biased and favor high-cardinality features
(typically numerical features). Permutation-based feature importances do not
exhibit such a bias. Additionally, the permutation feature importance may be
computed with an arbitrary metric on the tree's predictions. These two methods
of obtaining feature importance are explored in:
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`.

Strongly correlated features
----------------------------

When two features are correlated and one of the features is permuted, the model
still has access to the same information through its correlated feature. This
results in a lower reported importance for both features, even though they
might *actually* be important. One way to handle this is to cluster features
that are correlated and only keep one feature from each cluster. This use case
is explored in:
:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py`.

.. topic:: Examples:

    * :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`
    * :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py`

.. topic:: References:

    .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
       2001. https://doi.org/10.1023/A:1010933404324
177 changes: 177 additions & 0 deletions examples/inspection/plot_permutation_importance.py
@@ -0,0 +1,177 @@
"""
================================================================
Permutation Importance vs Random Forest Feature Importance (MDI)
================================================================

In this example, we will compare the impurity-based feature importance of
:class:`~sklearn.ensemble.RandomForestClassifier` with the
permutation importance on the titanic dataset using
:func:`~sklearn.inspection.permutation_importance`. We will show that the
impurity-based feature importance can inflate the importance of numerical
features.

Furthermore, the impurity-based feature importance of random forests suffers
from being computed on statistics derived from the training dataset: the
importances can be high even for features that are not predictive of the target
variable, as long as the model has the capacity to use them to overfit.

This example shows how to use Permutation Importances as an alternative that
can mitigate those limitations.

.. topic:: References:

.. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32,
2001. https://doi.org/10.1023/A:1010933404324
"""
print(__doc__)
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.inspection import permutation_importance
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder


##############################################################################
# Data Loading and Feature Engineering
# ------------------------------------
# Let's use pandas to load a copy of the titanic dataset. The following shows
# how to apply separate preprocessing on numerical and categorical features.
#
# We further include two random variables that are not correlated in any way
# with the target variable (``survived``):
#
# - ``random_num`` is a high cardinality numerical variable (as many unique
# values as records).
# - ``random_cat`` is a low cardinality categorical variable (3 possible
# values).
X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True)
X['random_cat'] = np.random.randint(3, size=X.shape[0])
X['random_num'] = np.random.randn(X.shape[0])

categorical_columns = ['pclass', 'sex', 'embarked', 'random_cat']
numerical_columns = ['age', 'sibsp', 'parch', 'fare', 'random_num']

X = X[categorical_columns + numerical_columns]

X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=42)

categorical_pipe = Pipeline([
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
('onehot', OneHotEncoder(handle_unknown='ignore'))
])
numerical_pipe = Pipeline([
('imputer', SimpleImputer(strategy='mean'))
])

preprocessing = ColumnTransformer(
[('cat', categorical_pipe, categorical_columns),
('num', numerical_pipe, numerical_columns)])

rf = Pipeline([
('preprocess', preprocessing),
('classifier', RandomForestClassifier(random_state=42))
])
rf.fit(X_train, y_train)

##############################################################################
# Accuracy of the Model
# ---------------------
# Prior to inspecting the feature importances, it is important to check that
# the model's predictive performance is high enough. Indeed, there would be
# little interest in inspecting the important features of a non-predictive
# model.
#
# Here one can observe that the train accuracy is very high (the forest model
# has enough capacity to completely memorize the training set) but it can still
# generalize well enough to the test set thanks to the built-in bagging of
# random forests.
#
# It might be possible to trade some accuracy on the training set for a
# slightly better accuracy on the test set by limiting the capacity of the
# trees (for instance by setting ``min_samples_leaf=5`` or
# ``min_samples_leaf=10``) so as to limit overfitting while not introducing too
# much underfitting.
#
# However let's keep our high capacity random forest model for now so as to
# illustrate some pitfalls with feature importance on variables with many
# unique values.
print("RF train accuracy: %0.3f" % rf.score(X_train, y_train))
print("RF test accuracy: %0.3f" % rf.score(X_test, y_test))


##############################################################################
# Tree's Feature Importance from Mean Decrease in Impurity (MDI)
# --------------------------------------------------------------
# The impurity-based feature importance ranks the numerical features as the
# most important features. As a result, the non-predictive ``random_num``
# variable is ranked the most important!
#
# This problem stems from two limitations of impurity-based feature
# importances:
#
# - impurity-based importances are biased towards high cardinality features;
# - impurity-based importances are computed on training set statistics and
# therefore do not reflect the ability of a feature to be useful to make
# predictions that generalize to the test set (when the model has enough
# capacity).
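#
# Below we first recover the feature names produced by the one-hot encoder so
# that the MDI importances can be matched to the preprocessed columns.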
ohe = (rf.named_steps['preprocess']
.named_transformers_['cat']
.named_steps['onehot'])
feature_names = ohe.get_feature_names(input_features=categorical_columns)
feature_names = np.r_[feature_names, numerical_columns]

tree_feature_importances = (
rf.named_steps['classifier'].feature_importances_)
sorted_idx = tree_feature_importances.argsort()

y_ticks = np.arange(0, len(feature_names))
fig, ax = plt.subplots()
ax.barh(y_ticks, tree_feature_importances[sorted_idx])
ax.set_yticklabels(feature_names[sorted_idx])
ax.set_yticks(y_ticks)
ax.set_title("Random Forest Feature Importances (MDI)")
fig.tight_layout()
plt.show()


##############################################################################
# As an alternative, the permutation importances of ``rf`` are computed on a
# held-out test set. This shows that the low-cardinality categorical feature,
# ``sex``, is the most important feature.
#
# Also note that both random features have very low importances (close to 0) as
# expected.
result = permutation_importance(rf, X_test, y_test, n_repeats=10,
random_state=42, n_jobs=2)
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()

##############################################################################
# It is also possible to compute the permutation importances on the training
# set. This reveals that ``random_num`` gets a significantly higher importance
# ranking than when computed on the test set. The difference between those two
# plots is a confirmation that the RF model has enough capacity to use that
# random numerical feature to overfit. You can further confirm this by
# re-running this example with a constrained RF with ``min_samples_leaf=10``.
result = permutation_importance(rf, X_train, y_train, n_repeats=10,
random_state=42, n_jobs=2)
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
vert=False, labels=X_train.columns[sorted_idx])
ax.set_title("Permutation Importances (train set)")
fig.tight_layout()
plt.show()
111 changes: 111 additions & 0 deletions examples/inspection/plot_permutation_importance_multicollinear.py
@@ -0,0 +1,111 @@
"""
=================================================================
Permutation Importance with Multicollinear or Correlated Features
=================================================================

In this example, we compute the permutation importance on the Wisconsin
breast cancer dataset using :func:`~sklearn.inspection.permutation_importance`.
The :class:`~sklearn.ensemble.RandomForestClassifier` can easily get about 97%
accuracy on a test dataset. Because this dataset contains multicollinear
features, the permutation importance will show that none of the features are
important. One approach to handling multicollinearity is to perform
hierarchical clustering on the features' Spearman rank-order correlations,
pick a threshold, and keep a single feature from each cluster.

.. note::
    See also
    :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`
"""
print(__doc__)
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr
from scipy.cluster import hierarchy

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

##############################################################################
# Random Forest Feature Importance on Breast Cancer Data
# ------------------------------------------------------
# First, we train a random forest on the breast cancer dataset and evaluate
# its accuracy on a test set:
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)
print("Accuracy on test data: {:.2f}".format(clf.score(X_test, y_test)))

##############################################################################
# Next, we plot the tree-based feature importance and the permutation
# importance. The permutation importance plot shows that permuting a feature
# drops the accuracy by at most ``0.012``, which would suggest that none of the
# features are important. This is in contradiction with the high test accuracy
# computed above: some feature must be important. The permutation importance
# is calculated on the training set to show how much the model relies on each
# feature during training.
result = permutation_importance(clf, X_train, y_train, n_repeats=10,
random_state=42)
perm_sorted_idx = result.importances_mean.argsort()

tree_importance_sorted_idx = np.argsort(clf.feature_importances_)
tree_indices = np.arange(1, len(clf.feature_importances_) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
ax1.barh(tree_indices, clf.feature_importances_[tree_importance_sorted_idx])
ax1.set_yticklabels(data.feature_names[tree_importance_sorted_idx])
ax1.set_yticks(tree_indices)
ax2.boxplot(result.importances[perm_sorted_idx].T, vert=False,
labels=data.feature_names)
fig.tight_layout()
plt.show()

##############################################################################
# Handling Multicollinear Features
# --------------------------------
# When features are collinear, permuting one feature will have little
# effect on the model's performance because it can get the same information
# from a correlated feature. One way to handle multicollinear features is by
# performing hierarchical clustering on the Spearman rank-order correlations,
# picking a threshold, and keeping a single feature from each cluster. First,
# we plot a heatmap of the correlated features:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
corr = spearmanr(X).correlation
Member:

Right now there is no example doing just this, right? Might be worth adding this to the feature selection section of the user guide?

Member:

Or we could add a feature selection method that actually implements this ;) I'm pretty sure there's an issue on that. Also would be a nice plotting function...

Member Author:

Adding a feature selection method based on correlation is on my todo list.

A data exploring plotting function? Hmm maybe.

Member Author:

> Might be worth adding this to the feature selection section of the user guide?

This is kind of done manually in this example. (Not easy to place into a pipeline)

Member:

> A data exploring plotting function? Hmm maybe.

sounds like a job for dabl, not sklearn ;)

Also see https://github.com/neurodata-nomads/pymeda

corr_linkage = hierarchy.ward(corr)
dendro = hierarchy.dendrogram(corr_linkage, labels=data.feature_names, ax=ax1,
leaf_rotation=90)
dendro_idx = np.arange(0, len(dendro['ivl']))

ax2.imshow(corr[dendro['leaves'], :][:, dendro['leaves']])
ax2.set_xticks(dendro_idx)
ax2.set_yticks(dendro_idx)
ax2.set_xticklabels(dendro['ivl'], rotation='vertical')
ax2.set_yticklabels(dendro['ivl'])
fig.tight_layout()
plt.show()

##############################################################################
# Next, we manually pick a threshold by visual inspection of the dendrogram
# to group our features into clusters, choose a single feature to keep from
# each cluster, select those features from our dataset, and train a new
# random forest. The test accuracy of the new random forest did not change
# much compared to the random forest trained on the complete dataset.
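#
# (The distance threshold of 1 passed to ``fcluster`` below was picked
# manually, by visual inspection of the dendrogram above.)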
cluster_ids = hierarchy.fcluster(corr_linkage, 1, criterion='distance')
Member:

How did you pick the threshold?

Member Author:

Visually

Member (@ogrisel ogrisel, Jul 3, 2019):

+1 for making this choice more explicit in the comment.
cluster_id_to_feature_ids = defaultdict(list)
for idx, cluster_id in enumerate(cluster_ids):
    cluster_id_to_feature_ids[cluster_id].append(idx)
selected_features = [v[0] for v in cluster_id_to_feature_ids.values()]

X_train_sel = X_train[:, selected_features]
X_test_sel = X_test[:, selected_features]

clf_sel = RandomForestClassifier(n_estimators=100, random_state=42)
clf_sel.fit(X_train_sel, y_train)
print("Accuracy on test data with features removed: {:.2f}".format(
clf_sel.score(X_test_sel, y_test)))