Skip to content

Commit 2af3fb8

Browse files
FEA Add metadata routing to SelectFromModel (#27490)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent aafe31f commit 2af3fb8

File tree

4 files changed

+91
-15
lines changed

4 files changed

+91
-15
lines changed

doc/metadata_routing.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ Meta-estimators and functions supporting metadata routing:
252252

253253
- :class:`sklearn.calibration.CalibratedClassifierCV`
254254
- :class:`sklearn.compose.ColumnTransformer`
255+
- :class:`sklearn.feature_selection.SelectFromModel`
255256
- :class:`sklearn.linear_model.ElasticNetCV`
256257
- :class:`sklearn.linear_model.LarsCV`
257258
- :class:`sklearn.linear_model.LassoCV`
@@ -290,7 +291,6 @@ Meta-estimators and tools not supporting metadata routing yet:
290291
- :class:`sklearn.ensemble.VotingRegressor`
291292
- :class:`sklearn.feature_selection.RFE`
292293
- :class:`sklearn.feature_selection.RFECV`
293-
- :class:`sklearn.feature_selection.SelectFromModel`
294294
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
295295
- :class:`sklearn.impute.IterativeImputer`
296296
- :class:`sklearn.linear_model.RANSACRegressor`

doc/whats_new/v1.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ more details.
107107
``**score_params`` which are passed to the underlying scorer.
108108
:pr:`26525` by :user:`Omar Salman <OmarManzoor>`.
109109

110+
- |Feature| :class:`feature_selection.SelectFromModel` now supports metadata
111+
routing in `fit` and `partial_fit`.
112+
:pr:`27490` by :user:`Stefanie Senger <StefanieSenger>`.
113+
110114
- |Feature| :class:`linear_model.OrthogonalMatchingPursuitCV` now supports
111115
metadata routing. Its `fit` now accepts ``**fit_params``, which are passed to
112116
the underlying splitter. :pr:`27500` by :user:`Stefanie Senger

sklearn/feature_selection/_from_model.py

Lines changed: 74 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from ..utils._param_validation import HasMethods, Interval, Options
1212
from ..utils._tags import _safe_tags
1313
from ..utils.metadata_routing import (
14-
_raise_for_unsupported_routing,
15-
_RoutingNotSupportedMixin,
14+
MetadataRouter,
15+
MethodMapping,
16+
_routing_enabled,
17+
process_routing,
1618
)
1719
from ..utils.metaestimators import available_if
1820
from ..utils.validation import _num_features, check_is_fitted, check_scalar
@@ -82,9 +84,7 @@ def _estimator_has(attr):
8284
)
8385

8486

85-
class SelectFromModel(
86-
_RoutingNotSupportedMixin, MetaEstimatorMixin, SelectorMixin, BaseEstimator
87-
):
87+
class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
8888
"""Meta-transformer for selecting features based on importance weights.
8989
9090
.. versionadded:: 0.17
@@ -341,14 +341,25 @@ def fit(self, X, y=None, **fit_params):
341341
classification, real numbers in regression).
342342
343343
**fit_params : dict
344-
Other estimator specific parameters.
344+
- If `enable_metadata_routing=False` (default):
345+
346+
Parameters directly passed to the `partial_fit` method of the
347+
sub-estimator. They are ignored if `prefit=True`.
348+
349+
- If `enable_metadata_routing=True`:
350+
351+
Parameters safely routed to the `partial_fit` method of the
352+
sub-estimator. They are ignored if `prefit=True`.
353+
354+
.. versionchanged:: 1.4
355+
See :ref:`Metadata Routing User Guide <metadata_routing>` for
356+
more details.
345357
346358
Returns
347359
-------
348360
self : object
349361
Fitted estimator.
350362
"""
351-
_raise_for_unsupported_routing(self, "fit", **fit_params)
352363
self._check_max_features(X)
353364

354365
if self.prefit:
@@ -361,8 +372,14 @@ def fit(self, X, y=None, **fit_params):
361372
) from exc
362373
self.estimator_ = deepcopy(self.estimator)
363374
else:
364-
self.estimator_ = clone(self.estimator)
365-
self.estimator_.fit(X, y, **fit_params)
375+
if _routing_enabled():
376+
routed_params = process_routing(self, "fit", **fit_params)
377+
self.estimator_ = clone(self.estimator)
378+
self.estimator_.fit(X, y, **routed_params.estimator.fit)
379+
else:
380+
# TODO(SLEP6): remove when metadata routing cannot be disabled.
381+
self.estimator_ = clone(self.estimator)
382+
self.estimator_.fit(X, y, **fit_params)
366383

367384
if hasattr(self.estimator_, "feature_names_in_"):
368385
self.feature_names_in_ = self.estimator_.feature_names_in_
@@ -387,7 +404,7 @@ def threshold_(self):
387404
# SelectFromModel.estimator is not validated yet
388405
prefer_skip_nested_validation=False
389406
)
390-
def partial_fit(self, X, y=None, **fit_params):
407+
def partial_fit(self, X, y=None, **partial_fit_params):
391408
"""Fit the SelectFromModel meta-transformer only once.
392409
393410
Parameters
@@ -399,8 +416,24 @@ def partial_fit(self, X, y=None, **fit_params):
399416
The target values (integers that correspond to classes in
400417
classification, real numbers in regression).
401418
402-
**fit_params : dict
403-
Other estimator specific parameters.
419+
**partial_fit_params : dict
420+
- If `enable_metadata_routing=False` (default):
421+
422+
Parameters directly passed to the `partial_fit` method of the
423+
sub-estimator.
424+
425+
- If `enable_metadata_routing=True`:
426+
427+
Parameters passed to the `partial_fit` method of the
428+
sub-estimator. They are ignored if `prefit=True`.
429+
430+
.. versionchanged:: 1.4
431+
`**partial_fit_params` are routed to the sub-estimator, if
432+
`enable_metadata_routing=True` is set via
433+
:func:`~sklearn.set_config`, which allows for aliasing.
434+
435+
See :ref:`Metadata Routing User Guide <metadata_routing>` for
436+
more details.
404437
405438
Returns
406439
-------
@@ -426,7 +459,13 @@ def partial_fit(self, X, y=None, **fit_params):
426459

427460
if first_call:
428461
self.estimator_ = clone(self.estimator)
429-
self.estimator_.partial_fit(X, y, **fit_params)
462+
if _routing_enabled():
463+
routed_params = process_routing(self, "partial_fit", **partial_fit_params)
464+
self.estimator_ = clone(self.estimator)
465+
self.estimator_.partial_fit(X, y, **routed_params.estimator.partial_fit)
466+
else:
467+
# TODO(SLEP6): remove when metadata routing cannot be disabled.
468+
self.estimator_.partial_fit(X, y, **partial_fit_params)
430469

431470
if hasattr(self.estimator_, "feature_names_in_"):
432471
self.feature_names_in_ = self.estimator_.feature_names_in_
@@ -451,5 +490,27 @@ def n_features_in_(self):
451490

452491
return self.estimator_.n_features_in_
453492

493+
def get_metadata_routing(self):
494+
"""Get metadata routing of this object.
495+
496+
Please check :ref:`User Guide <metadata_routing>` on how the routing
497+
mechanism works.
498+
499+
.. versionadded:: 1.4
500+
501+
Returns
502+
-------
503+
routing : MetadataRouter
504+
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
505+
routing information.
506+
"""
507+
router = MetadataRouter(owner=self.__class__.__name__).add(
508+
estimator=self.estimator,
509+
method_mapping=MethodMapping()
510+
.add(callee="partial_fit", caller="partial_fit")
511+
.add(callee="fit", caller="fit"),
512+
)
513+
return router
514+
454515
def _more_tags(self):
455516
return {"allow_nan": _safe_tags(self.estimator, key="allow_nan")}

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66

77
from sklearn import config_context
8+
from sklearn.base import is_classifier
89
from sklearn.calibration import CalibratedClassifierCV
910
from sklearn.compose import TransformedTargetRegressor
1011
from sklearn.covariance import GraphicalLassoCV
@@ -228,6 +229,15 @@ def enable_slep006():
228229
"y": y,
229230
"estimator_routing_methods": ["fit"],
230231
},
232+
{
233+
"metaestimator": SelectFromModel,
234+
"estimator_name": "estimator",
235+
"estimator": ConsumingClassifier,
236+
"X": X,
237+
"y": y,
238+
"estimator_routing_methods": ["fit", "partial_fit"],
239+
"method_args": {"partial_fit": {"classes": classes}},
240+
},
231241
{
232242
"metaestimator": OrthogonalMatchingPursuitCV,
233243
"X": X,
@@ -325,7 +335,6 @@ def enable_slep006():
325335
RFECV(ConsumingClassifier()),
326336
RidgeCV(),
327337
RidgeClassifierCV(),
328-
SelectFromModel(ConsumingClassifier()),
329338
SelfTrainingClassifier(ConsumingClassifier()),
330339
SequentialFeatureSelector(ConsumingClassifier()),
331340
StackingClassifier(ConsumingClassifier()),
@@ -477,6 +486,8 @@ def set_request(estimator, method_name):
477486
# e.g. call set_fit_request on estimator
478487
set_request_for_method = getattr(estimator, f"set_{method_name}_request")
479488
set_request_for_method(sample_weight=True, metadata=True)
489+
if is_classifier(estimator) and method_name == "partial_fit":
490+
set_request_for_method(classes=True)
480491

481492
cls = metaestimator["metaestimator"]
482493
X = metaestimator["X"]

0 commit comments

Comments
 (0)