Skip to content

FEAT rfecv: add support and ranking for each cv and step #30179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6fa6143
feat(rfecv): add support and ranking for each cv and step
MarieSacksick Oct 30, 2024
9897987
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Nov 4, 2024
714ac24
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Nov 26, 2024
701c098
add changelog
MarieSacksick Nov 26, 2024
e8e8ead
chore: precise version apparition
MarieSacksick Nov 26, 2024
0964620
chore: indent correctly version change
MarieSacksick Nov 26, 2024
6d8ac68
feat(rfecv): change attribute names & fix tests
MarieSacksick Dec 24, 2024
61417d8
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Dec 24, 2024
48cf76e
feat(rfecv): fix tests
MarieSacksick Dec 24, 2024
8b075cd
enrich docstring
MarieSacksick Feb 5, 2025
0d556b0
oupsy erase force
MarieSacksick Feb 5, 2025
6480fd1
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Feb 5, 2025
c1a3de1
chore: Write docstring for test function
MarieSacksick Feb 12, 2025
886f4df
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Feb 12, 2025
a271249
[doc build] enrich the rfecv example with the new support system
MarieSacksick Mar 18, 2025
d25ee14
chore: change github handle
MarieSacksick Mar 18, 2025
118798e
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Mar 19, 2025
87db900
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Mar 27, 2025
bb456c7
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
MarieSacksick Apr 2, 2025
179fba3
Merge branch 'main' into 17782_rfecv_provide_support_and_ranking
adrinjalali Apr 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- :class:`feature_selection.RFECV` now gives access to the ranking and support in each
iteration and cv step of feature selection.
By :user:`Marie S. <MarieSacksick>`
26 changes: 24 additions & 2 deletions examples/feature_selection/plot_rfe_with_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@

from sklearn.datasets import make_classification

n_features = 15
feat_names = [f"feature_{i}" for i in range(15)]

X, y = make_classification(
n_samples=500,
n_features=15,
n_features=n_features,
n_informative=3,
n_redundant=2,
n_repeated=0,
Expand Down Expand Up @@ -71,7 +74,12 @@
import matplotlib.pyplot as plt
import pandas as pd

cv_results = pd.DataFrame(rfecv.cv_results_)
data = {
key: value
for key, value in rfecv.cv_results_.items()
if key in ["n_features", "mean_test_score", "std_test_score"]
}
cv_results = pd.DataFrame(data)
plt.figure()
plt.xlabel("Number of features selected")
plt.ylabel("Mean test accuracy")
Expand All @@ -91,3 +99,17 @@
# cross-validation technique. The test accuracy decreases above 5 selected
# features, this is, keeping non-informative features leads to over-fitting and
# is therefore detrimental for the statistical performance of the models.

# %%
import numpy as np

for i in range(cv.n_splits):
mask = rfecv.cv_results_[f"split{i}_support"][
rfecv.n_features_
] # mask of features selected by the RFE
features_selected = np.ma.compressed(np.ma.masked_array(feat_names, mask=1 - mask))
print(f"Features selected in fold {i}: {features_selected}")
# %%
# In the five folds, the selected features are consistent. This is good news:
# it means that the selection is stable across folds, and it confirms that
# these features are the most informative ones.
46 changes: 36 additions & 10 deletions sklearn/feature_selection/_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer, routed_params):
**fit_params,
)

return rfe.step_scores_, rfe.step_n_features_
return rfe.step_scores_, rfe.step_support_, rfe.step_ranking_, rfe.step_n_features_


class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
Expand Down Expand Up @@ -318,6 +318,8 @@ def _fit(self, X, y, step_score=None, **fit_params):
if step_score:
self.step_n_features_ = []
self.step_scores_ = []
self.step_support_ = []
self.step_ranking_ = []

# Elimination
while np.sum(support_) > n_features_to_select:
Expand All @@ -331,6 +333,14 @@ def _fit(self, X, y, step_score=None, **fit_params):

estimator.fit(X[:, features], y, **fit_params)

# Compute step values on the previous selection iteration because
# 'estimator' must use features that have not been eliminated yet
if step_score:
self.step_n_features_.append(len(features))
self.step_scores_.append(step_score(estimator, features))
self.step_support_.append(list(support_))
self.step_ranking_.append(list(ranking_))

# Get importance and rank them
importances = _get_feature_importances(
estimator,
Expand All @@ -345,12 +355,6 @@ def _fit(self, X, y, step_score=None, **fit_params):
# Eliminate the worse features
threshold = min(step, np.sum(support_) - n_features_to_select)

# Compute step score on the previous selection iteration
# because 'estimator' must use features
# that have not been eliminated yet
if step_score:
self.step_n_features_.append(len(features))
self.step_scores_.append(step_score(estimator, features))
support_[features[ranks][:threshold]] = False
ranking_[np.logical_not(support_)] += 1

Expand All @@ -359,10 +363,12 @@ def _fit(self, X, y, step_score=None, **fit_params):
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X[:, features], y, **fit_params)

# Compute step score when only n_features_to_select features left
# Compute step values when only n_features_to_select features left
if step_score:
self.step_n_features_.append(len(features))
self.step_scores_.append(step_score(self.estimator_, features))
self.step_support_.append(support_)
self.step_ranking_.append(ranking_)
self.n_features_ = support_.sum()
self.support_ = support_
self.ranking_ = ranking_
Expand Down Expand Up @@ -674,6 +680,20 @@ class RFECV(RFE):

.. versionadded:: 1.5

split(k)_ranking : ndarray of shape (n_subsets_of_features,)
The cross-validation rankings in the (k)-th fold.
Selected (i.e., estimated best) features are assigned rank 1.
Illustration in
:ref:`sphx_glr_auto_examples_feature_selection_plot_rfe_with_cross_validation.py`

.. versionadded:: 1.7

split(k)_support : ndarray of shape (n_subsets_of_features,)
The cross-validation supports in the (k)-th fold. The support
is the mask of selected features.

.. versionadded:: 1.7

n_features_ : int
The number of selected features with cross-validation.

Expand Down Expand Up @@ -874,14 +894,16 @@ def fit(self, X, y, *, groups=None, **params):
parallel = Parallel(n_jobs=self.n_jobs)
func = delayed(_rfe_single_fit)

scores_features = parallel(
step_results = parallel(
func(clone(rfe), self.estimator, X, y, train, test, scorer, routed_params)
for train, test in cv.split(X, y, **routed_params.splitter.split)
)
scores, step_n_features = zip(*scores_features)
scores, supports, rankings, step_n_features = zip(*step_results)

step_n_features_rev = np.array(step_n_features[0])[::-1]
scores = np.array(scores)
rankings = np.array(rankings)
supports = np.array(supports)

# Reverse order such that lowest number of features is selected in case of tie.
scores_sum_rev = np.sum(scores, axis=0)[::-1]
Expand All @@ -907,10 +929,14 @@ def fit(self, X, y, *, groups=None, **params):

# reverse to stay consistent with before
scores_rev = scores[:, ::-1]
supports_rev = supports[:, ::-1]
rankings_rev = rankings[:, ::-1]
self.cv_results_ = {
"mean_test_score": np.mean(scores_rev, axis=0),
"std_test_score": np.std(scores_rev, axis=0),
**{f"split{i}_test_score": scores_rev[i] for i in range(scores.shape[0])},
**{f"split{i}_ranking": rankings_rev[i] for i in range(rankings.shape[0])},
**{f"split{i}_support": supports_rev[i] for i in range(supports.shape[0])},
"n_features": step_n_features_rev,
}
return self
Expand Down
34 changes: 33 additions & 1 deletion sklearn/feature_selection/tests/test_rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Testing Recursive feature elimination
"""

import re
from operator import attrgetter

import numpy as np
Expand Down Expand Up @@ -541,7 +542,11 @@ def test_rfecv_std_and_mean(global_random_seed):

rfecv = RFECV(estimator=SVC(kernel="linear"))
rfecv.fit(X, y)
split_keys = [key for key in rfecv.cv_results_.keys() if "split" in key]
split_keys = [
key
for key in rfecv.cv_results_.keys()
if re.search(r"split\d+_test_score", key)
]
cv_scores = np.asarray([rfecv.cv_results_[key] for key in split_keys])
expected_mean = np.mean(cv_scores, axis=0)
expected_std = np.std(cv_scores, axis=0)
Expand Down Expand Up @@ -721,3 +726,30 @@ def test_rfe_with_joblib_threading_backend(global_random_seed):
rfe.fit(X, y)

assert_array_equal(ranking_ref, rfe.ranking_)


def test_results_per_cv_in_rfecv(global_random_seed):
    """
    Test that the results of RFECV are consistent across the different folds
    in terms of length of the arrays.
    """
    X, y = make_classification(random_state=global_random_seed)

    selector = RFECV(
        estimator=LogisticRegression(),
        n_jobs=2,
        cv=5,
    )
    selector.fit(X, y)

    # Every per-split entry in cv_results_ must record the same number of
    # feature-elimination steps, regardless of which fold produced it.
    for suffix in ("test_score", "support", "ranking"):
        fold_one = selector.cv_results_[f"split1_{suffix}"]
        fold_two = selector.cv_results_[f"split2_{suffix}"]
        assert len(fold_one) == len(fold_two)