Skip to content

Commit e90dc9d

Browse files
lorentzbao, thomasjpfan, glemaitre
committed
FIX partial_fit from SelectFromModel doesn't validate the parameters (#23299)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 9d4c287 commit e90dc9d

File tree

3 files changed

+80
-25
lines changed

3 files changed

+80
-25
lines changed

sklearn/feature_selection/_from_model.py

+38-25
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
from ._base import _get_feature_importances
1111
from ..base import BaseEstimator, clone, MetaEstimatorMixin
1212
from ..utils._tags import _safe_tags
13-
from ..utils.validation import check_is_fitted
13+
from ..utils.validation import check_is_fitted, check_scalar, _num_features
1414

1515
from ..exceptions import NotFittedError
1616
from ..utils.metaestimators import available_if
17-
from ..utils.validation import check_scalar
1817

1918

2019
def _calculate_threshold(estimator, importances, threshold):
@@ -287,34 +286,17 @@ def _get_support_mask(self):
287286
mask[scores < threshold] = False
288287
return mask
289288

290-
def fit(self, X, y=None, **fit_params):
291-
"""Fit the SelectFromModel meta-transformer.
292-
293-
Parameters
294-
----------
295-
X : array-like of shape (n_samples, n_features)
296-
The training input samples.
297-
298-
y : array-like of shape (n_samples,), default=None
299-
The target values (integers that correspond to classes in
300-
classification, real numbers in regression).
301-
302-
**fit_params : dict
303-
Other estimator specific parameters.
304-
305-
Returns
306-
-------
307-
self : object
308-
Fitted estimator.
309-
"""
289+
def _check_max_features(self, X):
310290
if self.max_features is not None:
291+
n_features = _num_features(X)
292+
311293
if isinstance(self.max_features, numbers.Integral):
312294
check_scalar(
313295
self.max_features,
314296
"max_features",
315297
numbers.Integral,
316298
min_val=0,
317-
max_val=len(X[0]),
299+
max_val=n_features,
318300
)
319301
self.max_features_ = self.max_features
320302
elif callable(self.max_features):
@@ -324,7 +306,7 @@ def fit(self, X, y=None, **fit_params):
324306
"max_features(X)",
325307
numbers.Integral,
326308
min_val=0,
327-
max_val=len(X[0]),
309+
max_val=n_features,
328310
)
329311
self.max_features_ = max_features
330312
else:
@@ -333,6 +315,28 @@ def fit(self, X, y=None, **fit_params):
333315
f" 'X' as input. Got {self.max_features} instead."
334316
)
335317

318+
def fit(self, X, y=None, **fit_params):
319+
"""Fit the SelectFromModel meta-transformer.
320+
321+
Parameters
322+
----------
323+
X : array-like of shape (n_samples, n_features)
324+
The training input samples.
325+
326+
y : array-like of shape (n_samples,), default=None
327+
The target values (integers that correspond to classes in
328+
classification, real numbers in regression).
329+
330+
**fit_params : dict
331+
Other estimator specific parameters.
332+
333+
Returns
334+
-------
335+
self : object
336+
Fitted estimator.
337+
"""
338+
self._check_max_features(X)
339+
336340
if self.prefit:
337341
try:
338342
check_is_fitted(self.estimator)
@@ -385,6 +389,8 @@ def partial_fit(self, X, y=None, **fit_params):
385389
self : object
386390
Fitted estimator.
387391
"""
392+
self._check_max_features(X)
393+
388394
if self.prefit:
389395
if not hasattr(self, "estimator_"):
390396
try:
@@ -397,9 +403,16 @@ def partial_fit(self, X, y=None, **fit_params):
397403
self.estimator_ = deepcopy(self.estimator)
398404
return self
399405

400-
if not hasattr(self, "estimator_"):
406+
first_call = not hasattr(self, "estimator_")
407+
if first_call:
401408
self.estimator_ = clone(self.estimator)
402409
self.estimator_.partial_fit(X, y, **fit_params)
410+
411+
if hasattr(self.estimator_, "feature_names_in_"):
412+
self.feature_names_in_ = self.estimator_.feature_names_in_
413+
else:
414+
self._check_feature_names(X, reset=first_call)
415+
403416
return self
404417

405418
@property

sklearn/feature_selection/tests/test_from_model.py

+37
Original file line numberDiff line numberDiff line change
@@ -607,3 +607,40 @@ def importance_getter(estimator):
607607
warnings.simplefilter("error", UserWarning)
608608

609609
selector.transform(X.iloc[1:3])
610+
611+
612+
@pytest.mark.parametrize(
613+
"error, err_msg, max_features",
614+
(
615+
[ValueError, "max_features == 10, must be <= 4", 10],
616+
[TypeError, "'max_features' must be either an int or a callable", "a"],
617+
[ValueError, r"max_features\(X\) == 5, must be <= 4", lambda x: x.shape[1] + 1],
618+
),
619+
)
620+
def test_partial_fit_validate_max_features(error, err_msg, max_features):
621+
"""Test that partial_fit from SelectFromModel validates `max_features`."""
622+
X, y = datasets.make_classification(
623+
n_samples=100,
624+
n_features=4,
625+
random_state=0,
626+
)
627+
628+
with pytest.raises(error, match=err_msg):
629+
SelectFromModel(
630+
estimator=SGDClassifier(), max_features=max_features
631+
).partial_fit(X, y, classes=[0, 1])
632+
633+
634+
@pytest.mark.parametrize("as_frame", [True, False])
635+
def test_partial_fit_validate_feature_names(as_frame):
636+
"""Test that partial_fit from SelectFromModel validates `feature_names_in_`."""
637+
pytest.importorskip("pandas")
638+
X, y = datasets.load_iris(as_frame=as_frame, return_X_y=True)
639+
640+
selector = SelectFromModel(estimator=SGDClassifier(), max_features=4).partial_fit(
641+
X, y, classes=[0, 1, 2]
642+
)
643+
if as_frame:
644+
assert_array_equal(selector.feature_names_in_, X.columns)
645+
else:
646+
assert not hasattr(selector, "feature_names_in_")

sklearn/utils/estimator_checks.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..linear_model import LogisticRegression
3232
from ..linear_model import RANSACRegressor
3333
from ..linear_model import Ridge
34+
from ..linear_model import SGDRegressor
3435

3536
from ..base import (
3637
clone,
@@ -44,6 +45,7 @@
4445
from ..metrics import accuracy_score, adjusted_rand_score, f1_score
4546
from ..random_projection import BaseRandomProjection
4647
from ..feature_selection import SelectKBest
48+
from ..feature_selection import SelectFromModel
4749
from ..pipeline import make_pipeline
4850
from ..exceptions import DataConversionWarning
4951
from ..exceptions import NotFittedError
@@ -389,6 +391,9 @@ def _construct_instance(Estimator):
389391
estimator = Estimator(LinearRegression())
390392
elif issubclass(Estimator, RegressorMixin):
391393
estimator = Estimator(Ridge())
394+
elif issubclass(Estimator, SelectFromModel):
395+
# Increases coverage because SGDRegressor has partial_fit
396+
estimator = Estimator(SGDRegressor(random_state=0))
392397
else:
393398
estimator = Estimator(LogisticRegression(C=1))
394399
elif required_parameters in (["estimators"],):

0 commit comments

Comments (0)