11
11
from ..utils ._param_validation import HasMethods , Interval , Options
12
12
from ..utils ._tags import _safe_tags
13
13
from ..utils .metadata_routing import (
14
- _raise_for_unsupported_routing ,
15
- _RoutingNotSupportedMixin ,
14
+ MetadataRouter ,
15
+ MethodMapping ,
16
+ _routing_enabled ,
17
+ process_routing ,
16
18
)
17
19
from ..utils .metaestimators import available_if
18
20
from ..utils .validation import _num_features , check_is_fitted , check_scalar
@@ -82,9 +84,7 @@ def _estimator_has(attr):
82
84
)
83
85
84
86
85
- class SelectFromModel (
86
- _RoutingNotSupportedMixin , MetaEstimatorMixin , SelectorMixin , BaseEstimator
87
- ):
87
+ class SelectFromModel (MetaEstimatorMixin , SelectorMixin , BaseEstimator ):
88
88
"""Meta-transformer for selecting features based on importance weights.
89
89
90
90
.. versionadded:: 0.17
@@ -341,14 +341,25 @@ def fit(self, X, y=None, **fit_params):
341
341
classification, real numbers in regression).
342
342
343
343
**fit_params : dict
344
- Other estimator specific parameters.
344
+ - If `enable_metadata_routing=False` (default):
345
+
346
+ Parameters directly passed to the `partial_fit` method of the
347
+ sub-estimator. They are ignored if `prefit=True`.
348
+
349
+ - If `enable_metadata_routing=True`:
350
+
351
+ Parameters safely routed to the `partial_fit` method of the
352
+ sub-estimator. They are ignored if `prefit=True`.
353
+
354
+ .. versionchanged:: 1.4
355
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for
356
+ more details.
345
357
346
358
Returns
347
359
-------
348
360
self : object
349
361
Fitted estimator.
350
362
"""
351
- _raise_for_unsupported_routing (self , "fit" , ** fit_params )
352
363
self ._check_max_features (X )
353
364
354
365
if self .prefit :
@@ -361,8 +372,14 @@ def fit(self, X, y=None, **fit_params):
361
372
) from exc
362
373
self .estimator_ = deepcopy (self .estimator )
363
374
else :
364
- self .estimator_ = clone (self .estimator )
365
- self .estimator_ .fit (X , y , ** fit_params )
375
+ if _routing_enabled ():
376
+ routed_params = process_routing (self , "fit" , ** fit_params )
377
+ self .estimator_ = clone (self .estimator )
378
+ self .estimator_ .fit (X , y , ** routed_params .estimator .fit )
379
+ else :
380
+ # TODO(SLEP6): remove when metadata routing cannot be disabled.
381
+ self .estimator_ = clone (self .estimator )
382
+ self .estimator_ .fit (X , y , ** fit_params )
366
383
367
384
if hasattr (self .estimator_ , "feature_names_in_" ):
368
385
self .feature_names_in_ = self .estimator_ .feature_names_in_
@@ -387,7 +404,7 @@ def threshold_(self):
387
404
# SelectFromModel.estimator is not validated yet
388
405
prefer_skip_nested_validation = False
389
406
)
390
- def partial_fit (self , X , y = None , ** fit_params ):
407
+ def partial_fit (self , X , y = None , ** partial_fit_params ):
391
408
"""Fit the SelectFromModel meta-transformer only once.
392
409
393
410
Parameters
@@ -399,8 +416,24 @@ def partial_fit(self, X, y=None, **fit_params):
399
416
The target values (integers that correspond to classes in
400
417
classification, real numbers in regression).
401
418
402
- **fit_params : dict
403
- Other estimator specific parameters.
419
+ **partial_fit_params : dict
420
+ - If `enable_metadata_routing=False` (default):
421
+
422
+ Parameters directly passed to the `partial_fit` method of the
423
+ sub-estimator.
424
+
425
+ - If `enable_metadata_routing=True`:
426
+
427
+ Parameters passed to the `partial_fit` method of the
428
+ sub-estimator. They are ignored if `prefit=True`.
429
+
430
+ .. versionchanged:: 1.4
431
+ `**partial_fit_params` are routed to the sub-estimator, if
432
+ `enable_metadata_routing=True` is set via
433
+ :func:`~sklearn.set_config`, which allows for aliasing.
434
+
435
+ See :ref:`Metadata Routing User Guide <metadata_routing>` for
436
+ more details.
404
437
405
438
Returns
406
439
-------
@@ -426,7 +459,13 @@ def partial_fit(self, X, y=None, **fit_params):
426
459
427
460
if first_call :
428
461
self .estimator_ = clone (self .estimator )
429
- self .estimator_ .partial_fit (X , y , ** fit_params )
462
+ if _routing_enabled ():
463
+ routed_params = process_routing (self , "partial_fit" , ** partial_fit_params )
464
+ self .estimator_ = clone (self .estimator )
465
+ self .estimator_ .partial_fit (X , y , ** routed_params .estimator .partial_fit )
466
+ else :
467
+ # TODO(SLEP6): remove when metadata routing cannot be disabled.
468
+ self .estimator_ .partial_fit (X , y , ** partial_fit_params )
430
469
431
470
if hasattr (self .estimator_ , "feature_names_in_" ):
432
471
self .feature_names_in_ = self .estimator_ .feature_names_in_
@@ -451,5 +490,27 @@ def n_features_in_(self):
451
490
452
491
return self .estimator_ .n_features_in_
453
492
493
+ def get_metadata_routing (self ):
494
+ """Get metadata routing of this object.
495
+
496
+ Please check :ref:`User Guide <metadata_routing>` on how the routing
497
+ mechanism works.
498
+
499
+ .. versionadded:: 1.4
500
+
501
+ Returns
502
+ -------
503
+ routing : MetadataRouter
504
+ A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
505
+ routing information.
506
+ """
507
+ router = MetadataRouter (owner = self .__class__ .__name__ ).add (
508
+ estimator = self .estimator ,
509
+ method_mapping = MethodMapping ()
510
+ .add (callee = "partial_fit" , caller = "partial_fit" )
511
+ .add (callee = "fit" , caller = "fit" ),
512
+ )
513
+ return router
514
+
454
515
def _more_tags (self ):
455
516
return {"allow_nan" : _safe_tags (self .estimator , key = "allow_nan" )}
0 commit comments