Skip to content

add feature_indices param to permutation_importance() #30005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
22 changes: 21 additions & 1 deletion sklearn/inspection/_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _create_importances_bunch(baseline_score, permuted_score):
Interval(Integral, 1, None, closed="left"),
Interval(RealNotInt, 0, 1, closed="right"),
],
"feature_indices": ["array-like", None],
},
prefer_skip_nested_validation=True,
)
Expand All @@ -146,6 +147,7 @@ def permutation_importance(
random_state=None,
sample_weight=None,
max_samples=1.0,
feature_indices=None,
):
"""Permutation importance for feature evaluation [BRE]_.

Expand Down Expand Up @@ -230,6 +232,11 @@ def permutation_importance(

.. versionadded:: 1.0

feature_indices : array-like of int, default=None
If not None, compute the permutation importance only for the features
at these indices. If None, all features are used.

.. versionadded:: 1.7

Returns
-------
result : :class:`~sklearn.utils.Bunch` or dict of such instances
Expand Down Expand Up @@ -277,6 +284,19 @@ def permutation_importance(
random_state = check_random_state(random_state)
random_seed = random_state.randint(np.iinfo(np.int32).max + 1)

if feature_indices is not None:
feature_indices = check_array(
feature_indices, ensure_2d=False, ensure_min_features=0, dtype=None
)
if feature_indices.ndim != 1:
raise ValueError("feature_indices must be 1D array-like")
if not np.issubdtype(feature_indices.dtype, np.integer):
raise ValueError("feature_indices must be array-like of integers")
if np.any(feature_indices < 0) or np.any(feature_indices >= X.shape[1]):
raise ValueError("feature_indices must be within [0, n_features]")
else:
feature_indices = np.arange(X.shape[1])

if not isinstance(max_samples, numbers.Integral):
max_samples = int(max_samples * X.shape[0])
elif max_samples > X.shape[0]:
Expand All @@ -297,7 +317,7 @@ def permutation_importance(
scorer,
max_samples,
)
for col_idx in range(X.shape[1])
for col_idx in feature_indices
)

if isinstance(baseline_score, dict):
Expand Down
41 changes: 41 additions & 0 deletions sklearn/inspection/tests/test_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,44 @@ def test_permutation_importance_max_samples_error():

with pytest.raises(ValueError, match=err_msg):
permutation_importance(clf, X, y, max_samples=5)


def test_permutation_importance_feature_indices():
    """Check that only len(feature_indices) importances are returned."""
    X, y = make_regression(n_samples=500, n_features=10, random_state=0)
    estimator = LinearRegression().fit(X, y)

    n_repeats = 10
    selected_features = [0, 1, 2, 3, 9]

    result = permutation_importance(
        estimator,
        X,
        y,
        n_repeats=n_repeats,
        scoring="neg_mean_squared_error",
        feature_indices=selected_features,
    )

    # One row per requested feature, one column per repeat.
    expected_shape = (len(selected_features), n_repeats)
    assert result.importances.shape == expected_shape


def test_permutation_importance_feature_indices_out_of_range_error():
    """Check that out-of-range feature_indices raise a ValueError."""
    n_features = 10

    X, y = make_regression(n_samples=500, n_features=n_features, random_state=0)

    lr = LinearRegression().fit(X, y)

    # range(n_features + 1) includes the index `n_features`, which is one
    # past the last valid column and must be rejected.
    feature_indices_test = range(n_features + 1)

    # `match` is interpreted as a regular expression. The previous pattern
    # r"feature_indices must be within [0, n_features)" contained an
    # unbalanced ")" (a re.error at compile time) and an unintended
    # character class "[0, n_features]". Match on the literal,
    # regex-safe message prefix instead, which is also robust to the
    # exact interval notation used in the raised message.
    err_msg = "feature_indices must be within"

    with pytest.raises(ValueError, match=err_msg):
        permutation_importance(
            lr,
            X,
            y,
            feature_indices=feature_indices_test,
        )
Loading