diff --git a/doc/inspection.rst b/doc/inspection.rst index 745539d51bf77..b53aeb436b4cd 100644 --- a/doc/inspection.rst +++ b/doc/inspection.rst @@ -5,6 +5,19 @@ Inspection ---------- +Predictive performance is often the main goal of developing machine learning +models. Yet summarising performance with an evaluation metric is often +insufficient: it assumes that the evaluation metric and test dataset +perfectly reflect the target domain, which is rarely true. In certain domains, +a model needs a certain level of interpretability before it can be deployed. +A model that is exhibiting performance issues needs to be debugged for one to +understand the model's underlying issue. The +:mod:`sklearn.inspection` module provides tools to help understand the +predictions from a model and what affects them. This can be used to +evaluate assumptions and biases of a model, design a better model, or +to diagnose issues with model performance. + .. toctree:: modules/partial_dependence + modules/permutation_importance diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index d9c87362e5a11..30fc3b5102bc6 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -657,6 +657,7 @@ Kernels: :template: function.rst inspection.partial_dependence + inspection.permutation_importance inspection.plot_partial_dependence @@ -1257,7 +1258,6 @@ Model validation pipeline.make_pipeline pipeline.make_union - .. _preprocessing_ref: :mod:`sklearn.preprocessing`: Preprocessing and Normalization diff --git a/doc/modules/permutation_importance.rst b/doc/modules/permutation_importance.rst new file mode 100644 index 0000000000000..d1f850ffeb793 --- /dev/null +++ b/doc/modules/permutation_importance.rst @@ -0,0 +1,69 @@ + +.. _permutation_importance: + +Permutation feature importance +============================== + +.. 
currentmodule:: sklearn.inspection + +Permutation feature importance is a model inspection technique that can be used +for any `fitted` `estimator` when the data is rectangular. This is especially +useful for non-linear or opaque `estimators`. The permutation feature +importance is defined to be the decrease in a model score when the values of a +single feature are randomly shuffled [1]_. This procedure breaks the relationship between +the feature and the target, so the drop in the model score is indicative of +how much the model depends on the feature. This technique benefits from being +model agnostic and can be calculated many times with different permutations of +the feature. + +The :func:`permutation_importance` function calculates the feature importance +of `estimators` for a given dataset. The ``n_repeats`` parameter sets the number +of times a feature is randomly shuffled; the function returns a sample of feature +importances. Permutation importances can be computed either on the training set +or on a held-out testing or validation set. Using a held-out set makes it +possible to highlight which features contribute the most to the generalization +power of the inspected model. Features that are important on the training set +but not on the held-out set might cause the model to overfit. + +Note that features that are deemed non-important for some model with a +low predictive performance could be highly predictive for a model that +generalizes better. The conclusions should always be drawn in the context of +the specific model under inspection and cannot be automatically generalized to +the intrinsic predictive value of the features by themselves. Therefore it is +always important to evaluate the predictive power of a model using a held-out +set (or better with cross-validation) prior to computing importances. 
+ +Relation to impurity-based importance in trees +---------------------------------------------- + +Tree-based models provide a different measure of feature importance based +on the mean decrease in impurity (MDI, the splitting criterion). Because MDI is computed on the training set, it can give +importance to features that may not be predictive on unseen data. The +permutation feature importance avoids this issue, since it can be applied to +unseen data. Furthermore, impurity-based feature importances for trees +are strongly biased and favor high cardinality features +(typically numerical features). Permutation-based feature importances do not +exhibit such a bias. Additionally, the permutation feature importance may use +an arbitrary metric on the tree's predictions. These two methods of obtaining +feature importance are explored in: +:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py`. + +Strongly correlated features +---------------------------- + +When two features are correlated and one of the features is permuted, the model +will still have access to the permuted feature's information through the correlated feature. This will +result in a lower importance for both features, even though they might *actually* be +important. One way to handle this is to cluster features that are correlated +and only keep one feature from each cluster. This use case is explored in: +:ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py`. + +.. topic:: Examples: + + * :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py` + * :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance_multicollinear.py` + +.. topic:: References: + + .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, + 2001. 
https://doi.org/10.1023/A:1010933404324 diff --git a/examples/inspection/plot_permutation_importance.py b/examples/inspection/plot_permutation_importance.py new file mode 100644 index 0000000000000..c449573821a96 --- /dev/null +++ b/examples/inspection/plot_permutation_importance.py @@ -0,0 +1,177 @@ +""" +================================================================ +Permutation Importance vs Random Forest Feature Importance (MDI) +================================================================ + +In this example, we will compare the impurity-based feature importance of +:class:`~sklearn.ensemble.RandomForestClassifier` with the +permutation importance on the titanic dataset using +:func:`~sklearn.inspection.permutation_importance`. We will show that the +impurity-based feature importance can inflate the importance of numerical +features. + +Furthermore, the impurity-based feature importance of random forests suffers +from being computed on statistics derived from the training dataset: the +importances can be high even for features that are not predictive of the target +variable, as long as the model has the capacity to use them to overfit. + +This example shows how to use Permutation Importances as an alternative that +can mitigate those limitations. + +.. topic:: References: + + .. [1] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, + 2001. 
https://doi.org/10.1023/A:1010933404324 +""" +print(__doc__) +import matplotlib.pyplot as plt +import numpy as np + +from sklearn.datasets import fetch_openml +from sklearn.ensemble import RandomForestClassifier +from sklearn.impute import SimpleImputer +from sklearn.inspection import permutation_importance +from sklearn.compose import ColumnTransformer +from sklearn.model_selection import train_test_split +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import OneHotEncoder + + +############################################################################## +# Data Loading and Feature Engineering +# ------------------------------------ +# Let's use pandas to load a copy of the titanic dataset. The following shows +# how to apply separate preprocessing on numerical and categorical features. +# +# We further include two random variables that are not correlated in any way +# with the target variable (``survived``): +# +# - ``random_num`` is a high cardinality numerical variable (as many unique +# values as records). +# - ``random_cat`` is a low cardinality categorical variable (3 possible +# values). 
+X, y = fetch_openml("titanic", version=1, as_frame=True, return_X_y=True) +X['random_cat'] = np.random.randint(3, size=X.shape[0]) +X['random_num'] = np.random.randn(X.shape[0]) + +categorical_columns = ['pclass', 'sex', 'embarked', 'random_cat'] +numerical_columns = ['age', 'sibsp', 'parch', 'fare', 'random_num'] + +X = X[categorical_columns + numerical_columns] + +X_train, X_test, y_train, y_test = train_test_split( + X, y, stratify=y, random_state=42) + +categorical_pipe = Pipeline([ + ('imputer', SimpleImputer(strategy='constant', fill_value='missing')), + ('onehot', OneHotEncoder(handle_unknown='ignore')) +]) +numerical_pipe = Pipeline([ + ('imputer', SimpleImputer(strategy='mean')) +]) + +preprocessing = ColumnTransformer( + [('cat', categorical_pipe, categorical_columns), + ('num', numerical_pipe, numerical_columns)]) + +rf = Pipeline([ + ('preprocess', preprocessing), + ('classifier', RandomForestClassifier(random_state=42)) +]) +rf.fit(X_train, y_train) + +############################################################################## +# Accuracy of the Model +# --------------------- +# Prior to inspecting the feature importances, it is important to check that +# the model predictive performance is high enough. Indeed there would be little +# interest of inspecting the important features of a non-predictive model. +# +# Here one can observe that the train accuracy is very high (the forest model +# has enough capacity to completely memorize the training set) but it can still +# generalize well enough to the test set thanks to the built-in bagging of +# random forests. +# +# It might be possible to trade some accuracy on the training set for a +# slightly better accuracy on the test set by limiting the capacity of the +# trees (for instance by setting ``min_samples_leaf=5`` or +# ``min_samples_leaf=10``) so as to limit overfitting while not introducing too +# much underfitting. 
+# +# However let's keep our high capacity random forest model for now so as to +# illustrate some pitfalls with feature importance on variables with many +# unique values. +print("RF train accuracy: %0.3f" % rf.score(X_train, y_train)) +print("RF test accuracy: %0.3f" % rf.score(X_test, y_test)) + + +############################################################################## +# Tree's Feature Importance from Mean Decrease in Impurity (MDI) +# -------------------------------------------------------------- +# The impurity-based feature importance ranks the numerical features to be the +# most important features. As a result, the non-predictive ``random_num`` +# variable is ranked the most important! +# +# This problem stems from two limitations of impurity-based feature +# importances: +# +# - impurity-based importances are biased towards high cardinality features; +# - impurity-based importances are computed on training set statistics and +# therefore do not reflect the ability of a feature to be useful to make +# predictions that generalize to the test set (when the model has enough +# capacity). +ohe = (rf.named_steps['preprocess'] + .named_transformers_['cat'] + .named_steps['onehot']) +feature_names = ohe.get_feature_names(input_features=categorical_columns) +feature_names = np.r_[feature_names, numerical_columns] + +tree_feature_importances = ( + rf.named_steps['classifier'].feature_importances_) +sorted_idx = tree_feature_importances.argsort() + +y_ticks = np.arange(0, len(feature_names)) +fig, ax = plt.subplots() +ax.barh(y_ticks, tree_feature_importances[sorted_idx]) +ax.set_yticks(y_ticks) +ax.set_yticklabels(feature_names[sorted_idx]) +ax.set_title("Random Forest Feature Importances (MDI)") +fig.tight_layout() +plt.show() + + +############################################################################## +# As an alternative, the permutation importances of ``rf`` are computed on a +# held-out test set. 
This shows that the low cardinality categorical feature, +# ``sex``, is the most important feature. +# +# Also note that both random features have very low importances (close to 0) as +# expected. +result = permutation_importance(rf, X_test, y_test, n_repeats=10, + random_state=42, n_jobs=2) +sorted_idx = result.importances_mean.argsort() + +fig, ax = plt.subplots() +ax.boxplot(result.importances[sorted_idx].T, + vert=False, labels=X_test.columns[sorted_idx]) +ax.set_title("Permutation Importances (test set)") +fig.tight_layout() +plt.show() + +############################################################################## +# It is also possible to compute the permutation importances on the training +# set. This reveals that ``random_num`` gets a significantly higher importance +# ranking than when computed on the test set. The difference between those two +# plots is a confirmation that the RF model has enough capacity to use that +# random numerical feature to overfit. You can further confirm this by +# re-running this example with a constrained RF with ``min_samples_leaf=10``. 
+result = permutation_importance(rf, X_train, y_train, n_repeats=10, + random_state=42, n_jobs=2) +sorted_idx = result.importances_mean.argsort() + +fig, ax = plt.subplots() +ax.boxplot(result.importances[sorted_idx].T, + vert=False, labels=X_train.columns[sorted_idx]) +ax.set_title("Permutation Importances (train set)") +fig.tight_layout() +plt.show() diff --git a/examples/inspection/plot_permutation_importance_multicollinear.py b/examples/inspection/plot_permutation_importance_multicollinear.py new file mode 100644 index 0000000000000..460de614ed3b2 --- /dev/null +++ b/examples/inspection/plot_permutation_importance_multicollinear.py @@ -0,0 +1,111 @@ +""" +================================================================= +Permutation Importance with Multicollinear or Correlated Features +================================================================= + +In this example, we compute the permutation importance on the Wisconsin +breast cancer dataset using :func:`~sklearn.inspection.permutation_importance`. +The :class:`~sklearn.ensemble.RandomForestClassifier` can easily get about 97% +accuracy on a test dataset. Because this dataset contains multicollinear +features, the permutation importance will show that none of the features are +important. One approach to handling multicollinearity is by performing +hierarchical clustering on the features' Spearman rank-order correlations, +picking a threshold, and keeping a single feature from each cluster. + +.. 
note:: + See also + :ref:`sphx_glr_auto_examples_inspection_plot_permutation_importance.py` +""" +print(__doc__) +from collections import defaultdict + +import matplotlib.pyplot as plt +import numpy as np +from scipy.stats import spearmanr +from scipy.cluster import hierarchy + +from sklearn.datasets import load_breast_cancer +from sklearn.ensemble import RandomForestClassifier +from sklearn.inspection import permutation_importance +from sklearn.model_selection import train_test_split + +############################################################################## +# Random Forest Feature Importance on Breast Cancer Data +# ------------------------------------------------------ +# First, we train a random forest on the breast cancer dataset and evaluate +# its accuracy on a test set: +data = load_breast_cancer() +X, y = data.data, data.target +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + +clf = RandomForestClassifier(n_estimators=100, random_state=42) +clf.fit(X_train, y_train) +print("Accuracy on test data: {:.2f}".format(clf.score(X_test, y_test))) + +############################################################################## +# Next, we plot the tree based feature importance and the permutation +# importance. The permutation importance plot shows that permuting a feature +# drops the accuracy by at most `0.012`, which would suggest that none of the +# features are important. This is in contradiction with the high test accuracy +# computed above: some feature must be important. The permutation importance +# is calculated on the training set to show how much the model relies on each +# feature during training. 
+result = permutation_importance(clf, X_train, y_train, n_repeats=10, + random_state=42) +perm_sorted_idx = result.importances_mean.argsort() + +tree_importance_sorted_idx = np.argsort(clf.feature_importances_) +tree_indices = np.arange(1, len(clf.feature_importances_) + 1) + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) +ax1.barh(tree_indices, clf.feature_importances_[tree_importance_sorted_idx]) +ax1.set_yticks(tree_indices) +ax1.set_yticklabels(data.feature_names[tree_importance_sorted_idx]) +ax2.boxplot(result.importances[perm_sorted_idx].T, vert=False, + labels=data.feature_names[perm_sorted_idx]) +fig.tight_layout() +plt.show() + +############################################################################## +# Handling Multicollinear Features +# -------------------------------- +# When features are collinear, permuting one feature will have little +# effect on the model's performance because it can get the same information +# from a correlated feature. One way to handle multicollinear features is by +# performing hierarchical clustering on the Spearman rank-order correlations, +# picking a threshold, and keeping a single feature from each cluster. 
First, +# we plot a heatmap of the correlated features: +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) +corr = spearmanr(X).correlation +corr_linkage = hierarchy.ward(corr) +dendro = hierarchy.dendrogram(corr_linkage, labels=data.feature_names, ax=ax1, + leaf_rotation=90) +dendro_idx = np.arange(0, len(dendro['ivl'])) + +ax2.imshow(corr[dendro['leaves'], :][:, dendro['leaves']]) +ax2.set_xticks(dendro_idx) +ax2.set_yticks(dendro_idx) +ax2.set_xticklabels(dendro['ivl'], rotation='vertical') +ax2.set_yticklabels(dendro['ivl']) +fig.tight_layout() +plt.show() + +############################################################################## +# Next, we manually pick a threshold by visual inspection of the dendrogram +# to group our features into clusters and choose a feature from each cluster to +# keep, select those features from our dataset, and train a new random forest. +# The test accuracy of the new random forest did not change much compared to +# the random forest trained on the complete dataset. 
+cluster_ids = hierarchy.fcluster(corr_linkage, 1, criterion='distance') +cluster_id_to_feature_ids = defaultdict(list) +for idx, cluster_id in enumerate(cluster_ids): + cluster_id_to_feature_ids[cluster_id].append(idx) +selected_features = [v[0] for v in cluster_id_to_feature_ids.values()] + +X_train_sel = X_train[:, selected_features] +X_test_sel = X_test[:, selected_features] + +clf_sel = RandomForestClassifier(n_estimators=100, random_state=42) +clf_sel.fit(X_train_sel, y_train) +print("Accuracy on test data with features removed: {:.2f}".format( + clf_sel.score(X_test_sel, y_test))) diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index 2bf3fe14c0023..6670e4c576c4d 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -1,9 +1,10 @@ """The :mod:`sklearn.inspection` module includes tools for model inspection.""" from .partial_dependence import partial_dependence from .partial_dependence import plot_partial_dependence - +from .permutation_importance import permutation_importance __all__ = [ 'partial_dependence', 'plot_partial_dependence', + 'permutation_importance' ] diff --git a/sklearn/inspection/permutation_importance.py b/sklearn/inspection/permutation_importance.py new file mode 100644 index 0000000000000..8f63a6c000a36 --- /dev/null +++ b/sklearn/inspection/permutation_importance.py @@ -0,0 +1,126 @@ +"""Permutation importance for estimators""" +import numpy as np +from joblib import Parallel +from joblib import delayed + +from ..metrics import check_scoring +from ..utils import check_random_state +from ..utils import check_array +from ..utils import Bunch + + +def _safe_column_setting(X, col_idx, values): + """Set column on X using `col_idx`""" + if hasattr(X, "iloc"): + X.iloc[:, col_idx] = values + else: + X[:, col_idx] = values + + +def _safe_column_indexing(X, col_idx): + """Return column from X using `col_idx`""" + if hasattr(X, "iloc"): + return X.iloc[:, col_idx].values + else: + return X[:, 
col_idx] + + +def _calculate_permutation_scores(estimator, X, y, col_idx, random_state, + n_repeats, scorer): + """Calculate score when `col_idx` is permuted.""" + original_feature = _safe_column_indexing(X, col_idx).copy() + temp = original_feature.copy() + + scores = np.zeros(n_repeats) + for n_round in range(n_repeats): + random_state.shuffle(temp) + _safe_column_setting(X, col_idx, temp) + feature_score = scorer(estimator, X, y) + scores[n_round] = feature_score + + _safe_column_setting(X, col_idx, original_feature) + return scores + + +def permutation_importance(estimator, X, y, scoring=None, n_repeats=5, + n_jobs=None, random_state=None): + """Permutation importance for feature evaluation [BRE]_. + + The `estimator` is required to be a fitted estimator. `X` can be the + data set used to train the estimator or a held-out set. The permutation + importance of a feature is calculated as follows. First, a baseline metric, + defined by `scoring`, is evaluated on the dataset + defined by `X`. Next, a feature column of `X` is + permuted and the metric is evaluated again. The permutation importance is + defined to be the difference between the baseline metric and the metric obtained after + permuting the feature column. + + Read more in the :ref:`User Guide <permutation_importance>`. + + Parameters + ---------- + estimator : object + An estimator that has already been `fitted` and is compatible with + `scorer`. + + X : ndarray or DataFrame, shape (n_samples, n_features) + Data on which permutation importance will be computed. + + y : array-like or None, shape (n_samples, ) or (n_samples, n_classes) + Targets for supervised or `None` for unsupervised. + + scoring : string, callable or None, default=None + Scorer to use. It can be a single + string (see :ref:`scoring_parameter`) or a callable (see + :ref:`scoring`). If None, the estimator's default scorer is used. + + n_repeats : int, default=5 + Number of times to permute a feature. 
+ + n_jobs : int or None, default=None + The number of jobs to use for the computation. + `None` means 1 unless in a :obj:`joblib.parallel_backend` context. + `-1` means using all processors. See :term:`Glossary <n_jobs>` + for more details. + + random_state : int, RandomState instance, or None, default=None + Pseudo-random number generator to control the permutations of each + feature. See :term:`random_state`. + + Returns + ------- + result : Bunch + Dictionary-like object, with attributes: + + importances_mean : ndarray, shape (n_features, ) + Mean of feature importance over `n_repeats`. + importances_std : ndarray, shape (n_features, ) + Standard deviation over `n_repeats`. + importances : ndarray, shape (n_features, n_repeats) + Raw permutation importance scores. + + References + ---------- + .. [BRE] L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, + 2001. https://doi.org/10.1023/A:1010933404324 + """ + if hasattr(X, "iloc"): + X = X.copy() # Dataframe + else: + X = check_array(X, force_all_finite='allow-nan', dtype=np.object, + copy=True) + + random_state = check_random_state(random_state) + scorer = check_scoring(estimator, scoring=scoring) + + baseline_score = scorer(estimator, X, y) + scores = np.zeros((X.shape[1], n_repeats)) + + scores = Parallel(n_jobs=n_jobs)(delayed(_calculate_permutation_scores)( + estimator, X, y, col_idx, random_state, n_repeats, scorer + ) for col_idx in range(X.shape[1])) + + importances = baseline_score - np.array(scores) + return Bunch(importances_mean=np.mean(importances, axis=1), + importances_std=np.std(importances, axis=1), + importances=importances) diff --git a/sklearn/inspection/tests/test_permutation_importance.py b/sklearn/inspection/tests/test_permutation_importance.py new file mode 100644 index 0000000000000..9394202cfce97 --- /dev/null +++ b/sklearn/inspection/tests/test_permutation_importance.py @@ -0,0 +1,154 @@ +import pytest +import numpy as np + +from numpy.testing import assert_allclose + +from 
sklearn.compose import ColumnTransformer +from sklearn.datasets import load_boston +from sklearn.datasets import load_iris +from sklearn.datasets import make_regression +from sklearn.ensemble import RandomForestRegressor +from sklearn.ensemble import RandomForestClassifier +from sklearn.linear_model import LinearRegression +from sklearn.linear_model import LogisticRegression +from sklearn.impute import SimpleImputer +from sklearn.inspection import permutation_importance +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import OneHotEncoder +from sklearn.preprocessing import StandardScaler +from sklearn.preprocessing import scale + + +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_permutation_importance_correlated_feature_regression(n_jobs): + # Make sure that feature highly correlated to the target have a higher + # importance + rng = np.random.RandomState(42) + n_repeats = 5 + + dataset = load_boston() + X, y = dataset.data, dataset.target + y_with_little_noise = ( + y + rng.normal(scale=0.001, size=y.shape[0])).reshape(-1, 1) + + X = np.hstack([X, y_with_little_noise]) + + clf = RandomForestRegressor(n_estimators=10, random_state=42) + clf.fit(X, y) + + result = permutation_importance(clf, X, y, n_repeats=n_repeats, + random_state=rng, n_jobs=n_jobs) + + assert result.importances.shape == (X.shape[1], n_repeats) + + # the correlated feature with y was added as the last column and should + # have the highest importance + assert np.all(result.importances_mean[-1] > + result.importances_mean[:-1]) + + +@pytest.mark.parametrize("n_jobs", [1, 2]) +def test_permutation_importance_correlated_feature_regression_pandas(n_jobs): + pd = pytest.importorskip("pandas") + + # Make sure that feature highly correlated to the target have a higher + # importance + rng = np.random.RandomState(42) + n_repeats = 5 + + dataset = load_iris() + X, y = dataset.data, dataset.target + y_with_little_noise = ( + y + rng.normal(scale=0.001, 
size=y.shape[0])).reshape(-1, 1) + + # Adds feature correlated with y as the last column + X = pd.DataFrame(X, columns=dataset.feature_names) + X['correlated_feature'] = y_with_little_noise + + clf = RandomForestClassifier(n_estimators=10, random_state=42) + clf.fit(X, y) + + result = permutation_importance(clf, X, y, n_repeats=n_repeats, + random_state=rng, n_jobs=n_jobs) + + assert result.importances.shape == (X.shape[1], n_repeats) + + # the correlated feature with y was added as the last column and should + # have the highest importance + assert np.all(result.importances_mean[-1] > result.importances_mean[:-1]) + + +def test_permutation_importance_mixed_types(): + rng = np.random.RandomState(42) + n_repeats = 4 + + # Last column is correlated with y + X = np.array([[1.0, 2.0, 3.0, np.nan], [2, 1, 2, 1]]).T + y = np.array([0, 1, 0, 1]) + + clf = make_pipeline(SimpleImputer(), LogisticRegression(solver='lbfgs')) + clf.fit(X, y) + result = permutation_importance(clf, X, y, n_repeats=n_repeats, + random_state=rng) + + assert result.importances.shape == (X.shape[1], n_repeats) + + # the correlated feature with y is the last column and should + # have the highest importance + assert np.all(result.importances_mean[-1] > result.importances_mean[:-1]) + + # use another random state + rng = np.random.RandomState(0) + result2 = permutation_importance(clf, X, y, n_repeats=n_repeats, + random_state=rng) + assert result2.importances.shape == (X.shape[1], n_repeats) + + assert not np.allclose(result.importances, result2.importances) + + # the correlated feature with y is the last column and should + # have the highest importance + assert np.all(result2.importances_mean[-1] > result2.importances_mean[:-1]) + + +def test_permutation_importance_mixed_types_pandas(): + pd = pytest.importorskip("pandas") + rng = np.random.RandomState(42) + n_repeats = 5 + + # Last column is correlated with y + X = pd.DataFrame({'col1': [1.0, 2.0, 3.0, np.nan], + 'col2': ['a', 'b', 'a', 'b']}) + y 
= np.array([0, 1, 0, 1]) + + num_preprocess = make_pipeline(SimpleImputer(), StandardScaler()) + preprocess = ColumnTransformer([ + ('num', num_preprocess, ['col1']), + ('cat', OneHotEncoder(), ['col2']) + ]) + clf = make_pipeline(preprocess, LogisticRegression(solver='lbfgs')) + clf.fit(X, y) + + result = permutation_importance(clf, X, y, n_repeats=n_repeats, + random_state=rng) + + assert result.importances.shape == (X.shape[1], n_repeats) + # the correlated feature with y is the last column and should + # have the highest importance + assert np.all(result.importances_mean[-1] > result.importances_mean[:-1]) + + +def test_permutation_importance_linear_regression(): + X, y = make_regression(n_samples=500, n_features=10, random_state=0) + + X = scale(X) + y = scale(y) + + lr = LinearRegression().fit(X, y) + + # closed form for unit-variance features: the MSE rises by 2 * coef_j**2 + expected_importances = 2 * lr.coef_**2 + results = permutation_importance(lr, X, y, + n_repeats=50, + scoring='neg_mean_squared_error') + assert_allclose(expected_importances, results.importances_mean, + rtol=1e-1, atol=1e-6)
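Reviewer note on the closed form used by the linear regression test: permuting feature j changes each prediction of a linear model by `coef_j * (x_shuffled - x_original)`. If the feature has unit variance and the shuffle is independent of the original order, the expected squared change is `2 * coef_j**2`, and the cross term with the residual vanishes in expectation, so the drop in negative MSE should be close to `2 * coef_j**2`. The sketch below checks this without scikit-learn, under simplifying assumptions: the targets are noise-free, the "fitted" coefficients are hard-coded, and the helper names (`standardize`, `neg_mse`, `permutation_importances`) are illustrative only and not part of this patch.

```python
import random
from statistics import mean, pstdev


def standardize(column):
    """Rescale a list of floats to zero mean and unit (population) variance."""
    mu, sigma = mean(column), pstdev(column)
    return [(v - mu) / sigma for v in column]


def neg_mse(coefs, columns, y):
    """Negative mean squared error of the linear model sum_j coefs[j] * x_j."""
    total = 0.0
    for i in range(len(y)):
        y_hat = sum(c * col[i] for c, col in zip(coefs, columns))
        total += (y[i] - y_hat) ** 2
    return -total / len(y)


def permutation_importances(coefs, columns, y, n_repeats, rng):
    """Mean drop in score after shuffling each feature column in turn."""
    baseline = neg_mse(coefs, columns, y)
    importances = []
    for j, original in enumerate(columns):
        drops = []
        for _ in range(n_repeats):
            shuffled = original[:]  # copy so the input data is left untouched
            rng.shuffle(shuffled)
            permuted = columns[:j] + [shuffled] + columns[j + 1:]
            drops.append(baseline - neg_mse(coefs, permuted, y))
        importances.append(mean(drops))
    return importances


rng = random.Random(0)
n_samples = 2000
coefs = [1.5, -0.5, 0.0]  # hypothetical coefficients standing in for a fitted model
columns = [standardize([rng.gauss(0, 1) for _ in range(n_samples)])
           for _ in coefs]
# Noise-free targets, so the baseline error is zero by construction.
y = [sum(c * col[i] for c, col in zip(coefs, columns))
     for i in range(n_samples)]

importances = permutation_importances(coefs, columns, y, n_repeats=20, rng=rng)
expected = [2 * c ** 2 for c in coefs]
for got, want in zip(importances, expected):
    print(f"estimated={got:.3f}  closed form={want:.3f}")
```

The estimates only approach the closed form as the sample size and the number of repeats grow, which is presumably why the test compares them with a relative tolerance rather than exactly.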