10
10
from ._base import _get_feature_importances
11
11
from ..base import BaseEstimator , clone , MetaEstimatorMixin
12
12
from ..utils ._tags import _safe_tags
13
- from ..utils .validation import check_is_fitted
13
+ from ..utils .validation import check_is_fitted , check_scalar , _num_features
14
14
15
15
from ..exceptions import NotFittedError
16
16
from ..utils .metaestimators import available_if
17
- from ..utils .validation import check_scalar
18
17
19
18
20
19
def _calculate_threshold (estimator , importances , threshold ):
@@ -287,34 +286,17 @@ def _get_support_mask(self):
287
286
mask [scores < threshold ] = False
288
287
return mask
289
288
290
- def fit (self , X , y = None , ** fit_params ):
291
- """Fit the SelectFromModel meta-transformer.
292
-
293
- Parameters
294
- ----------
295
- X : array-like of shape (n_samples, n_features)
296
- The training input samples.
297
-
298
- y : array-like of shape (n_samples,), default=None
299
- The target values (integers that correspond to classes in
300
- classification, real numbers in regression).
301
-
302
- **fit_params : dict
303
- Other estimator specific parameters.
304
-
305
- Returns
306
- -------
307
- self : object
308
- Fitted estimator.
309
- """
289
+ def _check_max_features (self , X ):
310
290
if self .max_features is not None :
291
+ n_features = _num_features (X )
292
+
311
293
if isinstance (self .max_features , numbers .Integral ):
312
294
check_scalar (
313
295
self .max_features ,
314
296
"max_features" ,
315
297
numbers .Integral ,
316
298
min_val = 0 ,
317
- max_val = len ( X [ 0 ]) ,
299
+ max_val = n_features ,
318
300
)
319
301
self .max_features_ = self .max_features
320
302
elif callable (self .max_features ):
@@ -324,7 +306,7 @@ def fit(self, X, y=None, **fit_params):
324
306
"max_features(X)" ,
325
307
numbers .Integral ,
326
308
min_val = 0 ,
327
- max_val = len ( X [ 0 ]) ,
309
+ max_val = n_features ,
328
310
)
329
311
self .max_features_ = max_features
330
312
else :
@@ -333,6 +315,28 @@ def fit(self, X, y=None, **fit_params):
333
315
f" 'X' as input. Got { self .max_features } instead."
334
316
)
335
317
318
+ def fit (self , X , y = None , ** fit_params ):
319
+ """Fit the SelectFromModel meta-transformer.
320
+
321
+ Parameters
322
+ ----------
323
+ X : array-like of shape (n_samples, n_features)
324
+ The training input samples.
325
+
326
+ y : array-like of shape (n_samples,), default=None
327
+ The target values (integers that correspond to classes in
328
+ classification, real numbers in regression).
329
+
330
+ **fit_params : dict
331
+ Other estimator specific parameters.
332
+
333
+ Returns
334
+ -------
335
+ self : object
336
+ Fitted estimator.
337
+ """
338
+ self ._check_max_features (X )
339
+
336
340
if self .prefit :
337
341
try :
338
342
check_is_fitted (self .estimator )
@@ -385,6 +389,8 @@ def partial_fit(self, X, y=None, **fit_params):
385
389
self : object
386
390
Fitted estimator.
387
391
"""
392
+ self ._check_max_features (X )
393
+
388
394
if self .prefit :
389
395
if not hasattr (self , "estimator_" ):
390
396
try :
@@ -397,9 +403,16 @@ def partial_fit(self, X, y=None, **fit_params):
397
403
self .estimator_ = deepcopy (self .estimator )
398
404
return self
399
405
400
- if not hasattr (self , "estimator_" ):
406
+ first_call = not hasattr (self , "estimator_" )
407
+ if first_call :
401
408
self .estimator_ = clone (self .estimator )
402
409
self .estimator_ .partial_fit (X , y , ** fit_params )
410
+
411
+ if hasattr (self .estimator_ , "feature_names_in_" ):
412
+ self .feature_names_in_ = self .estimator_ .feature_names_in_
413
+ else :
414
+ self ._check_feature_names (X , reset = first_call )
415
+
403
416
return self
404
417
405
418
@property
0 commit comments