Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ Meta-estimators and functions supporting metadata routing:

- :class:`sklearn.calibration.CalibratedClassifierCV`
- :class:`sklearn.compose.ColumnTransformer`
- :class:`sklearn.feature_selection.SelectFromModel`
- :class:`sklearn.linear_model.ElasticNetCV`
- :class:`sklearn.linear_model.LarsCV`
- :class:`sklearn.linear_model.LassoCV`
Expand Down Expand Up @@ -290,7 +291,6 @@ Meta-estimators and tools not supporting metadata routing yet:
- :class:`sklearn.ensemble.VotingRegressor`
- :class:`sklearn.feature_selection.RFE`
- :class:`sklearn.feature_selection.RFECV`
- :class:`sklearn.feature_selection.SelectFromModel`
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
- :class:`sklearn.impute.IterativeImputer`
- :class:`sklearn.linear_model.RANSACRegressor`
Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ more details.
``**score_params`` which are passed to the underlying scorer.
:pr:`26525` by :user:`Omar Salman <OmarManzoor>`.

- |Feature| :class:`feature_selection.SelectFromModel` now supports metadata
routing in `fit` and `partial_fit`.
:pr:`27490` by :user:`Stefanie Senger <StefanieSenger>`.

- |Feature| :class:`linear_model.OrthogonalMatchingPursuitCV` now supports
metadata routing. Its `fit` now accepts ``**fit_params``, which are passed to
the underlying splitter. :pr:`27500` by :user:`Stefanie Senger
Expand Down
87 changes: 74 additions & 13 deletions sklearn/feature_selection/_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from ..utils._param_validation import HasMethods, Interval, Options
from ..utils._tags import _safe_tags
from ..utils.metadata_routing import (
_raise_for_unsupported_routing,
_RoutingNotSupportedMixin,
MetadataRouter,
MethodMapping,
_routing_enabled,
process_routing,
)
from ..utils.metaestimators import available_if
from ..utils.validation import _num_features, check_is_fitted, check_scalar
Expand Down Expand Up @@ -82,9 +84,7 @@ def _estimator_has(attr):
)


class SelectFromModel(
_RoutingNotSupportedMixin, MetaEstimatorMixin, SelectorMixin, BaseEstimator
):
class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
"""Meta-transformer for selecting features based on importance weights.

.. versionadded:: 0.17
Expand Down Expand Up @@ -341,14 +341,25 @@ def fit(self, X, y=None, **fit_params):
classification, real numbers in regression).

**fit_params : dict
Other estimator specific parameters.
- If `enable_metadata_routing=False` (default):

Parameters directly passed to the `fit` method of the
sub-estimator. They are ignored if `prefit=True`.

- If `enable_metadata_routing=True`:

Parameters safely routed to the `fit` method of the
sub-estimator. They are ignored if `prefit=True`.

.. versionchanged:: 1.4
See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Returns
-------
self : object
Fitted estimator.
"""
_raise_for_unsupported_routing(self, "fit", **fit_params)
self._check_max_features(X)

if self.prefit:
Expand All @@ -361,8 +372,14 @@ def fit(self, X, y=None, **fit_params):
) from exc
self.estimator_ = deepcopy(self.estimator)
else:
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X, y, **fit_params)
if _routing_enabled():
routed_params = process_routing(self, "fit", **fit_params)
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X, y, **routed_params.estimator.fit)
else:
# TODO(SLEP6): remove when metadata routing cannot be disabled.
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X, y, **fit_params)

if hasattr(self.estimator_, "feature_names_in_"):
self.feature_names_in_ = self.estimator_.feature_names_in_
Expand All @@ -387,7 +404,7 @@ def threshold_(self):
# SelectFromModel.estimator is not validated yet
prefer_skip_nested_validation=False
)
def partial_fit(self, X, y=None, **fit_params):
def partial_fit(self, X, y=None, **partial_fit_params):
"""Fit the SelectFromModel meta-transformer only once.

Parameters
Expand All @@ -399,8 +416,24 @@ def partial_fit(self, X, y=None, **fit_params):
The target values (integers that correspond to classes in
classification, real numbers in regression).

**fit_params : dict
Other estimator specific parameters.
**partial_fit_params : dict
- If `enable_metadata_routing=False` (default):

Parameters directly passed to the `partial_fit` method of the
sub-estimator.

- If `enable_metadata_routing=True`:

Parameters passed to the `partial_fit` method of the
sub-estimator. They are ignored if `prefit=True`.

.. versionchanged:: 1.4
`**partial_fit_params` are routed to the sub-estimator, if
`enable_metadata_routing=True` is set via
:func:`~sklearn.set_config`, which allows for aliasing.
Comment on lines +431 to +433
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
`**partial_fit_params` are routed to the sub-estimator, if
`enable_metadata_routing=True` is set via
:func:`~sklearn.set_config`, which allows for aliasing.
Only available if `enable_metadata_routing=True`,
which can be set by using
``sklearn.set_config(enable_metadata_routing=True)``.
See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I would suggest leaving out most of the part below `.. versionchanged:: 1.4`. We have already written all of this before and after it, and it is difficult to read.

What about this?

         **partial_fit_params : dict
            - If `enable_metadata_routing=False` (default):

                Parameters directly passed to the `partial_fit` method of the
                sub-estimator. They are ignored if `prefit=True`.

            - If `enable_metadata_routing=True`:

                Parameters safely routed to the `partial_fit` method of the
                sub-estimator. They are ignored if `prefit=True`.

                .. versionchanged:: 1.4
                    See :ref:`Metadata Routing User Guide <metadata_routing>` for
                    more details.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.


See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Returns
-------
Expand All @@ -426,7 +459,13 @@ def partial_fit(self, X, y=None, **fit_params):

if first_call:
self.estimator_ = clone(self.estimator)
self.estimator_.partial_fit(X, y, **fit_params)
if _routing_enabled():
routed_params = process_routing(self, "partial_fit", **partial_fit_params)
self.estimator_ = clone(self.estimator)
self.estimator_.partial_fit(X, y, **routed_params.estimator.partial_fit)
else:
# TODO(SLEP6): remove when metadata routing cannot be disabled.
self.estimator_.partial_fit(X, y, **partial_fit_params)

if hasattr(self.estimator_, "feature_names_in_"):
self.feature_names_in_ = self.estimator_.feature_names_in_
Expand All @@ -451,5 +490,27 @@ def n_features_in_(self):

return self.estimator_.n_features_in_

def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    .. versionadded:: 1.4

    Returns
    -------
    routing : MetadataRouter
        A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    # Each caller method of this meta-estimator is mapped onto the method of
    # the same name on the wrapped estimator. The registration order
    # (`partial_fit` first, then `fit`) is kept as-is.
    estimator_mapping = MethodMapping().add(
        callee="partial_fit", caller="partial_fit"
    )
    estimator_mapping = estimator_mapping.add(callee="fit", caller="fit")

    owner_name = self.__class__.__name__
    return MetadataRouter(owner=owner_name).add(
        estimator=self.estimator,
        method_mapping=estimator_mapping,
    )

def _more_tags(self):
    """Return extra estimator tags for this meta-transformer.

    The ``allow_nan`` tag is not decided here; it is looked up on the
    wrapped ``estimator`` through ``_safe_tags``.
    """
    estimator_allows_nan = _safe_tags(self.estimator, key="allow_nan")
    return {"allow_nan": estimator_allows_nan}
13 changes: 12 additions & 1 deletion sklearn/tests/test_metaestimators_metadata_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from sklearn import config_context
from sklearn.base import is_classifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.compose import TransformedTargetRegressor
from sklearn.covariance import GraphicalLassoCV
Expand Down Expand Up @@ -228,6 +229,15 @@ def enable_slep006():
"y": y,
"estimator_routing_methods": ["fit"],
},
{
"metaestimator": SelectFromModel,
"estimator_name": "estimator",
"estimator": ConsumingClassifier,
"X": X,
"y": y,
"estimator_routing_methods": ["fit", "partial_fit"],
"method_args": {"partial_fit": {"classes": classes}},
},
{
"metaestimator": OrthogonalMatchingPursuitCV,
"X": X,
Expand Down Expand Up @@ -325,7 +335,6 @@ def enable_slep006():
RFECV(ConsumingClassifier()),
RidgeCV(),
RidgeClassifierCV(),
SelectFromModel(ConsumingClassifier()),
SelfTrainingClassifier(ConsumingClassifier()),
SequentialFeatureSelector(ConsumingClassifier()),
StackingClassifier(ConsumingClassifier()),
Expand Down Expand Up @@ -477,6 +486,8 @@ def set_request(estimator, method_name):
# e.g. call set_fit_request on estimator
set_request_for_method = getattr(estimator, f"set_{method_name}_request")
set_request_for_method(sample_weight=True, metadata=True)
if is_classifier(estimator) and method_name == "partial_fit":
set_request_for_method(classes=True)

cls = metaestimator["metaestimator"]
X = metaestimator["X"]
Expand Down