Skip to content

Commit 5fb02bb

Browse files
authored
ENH Adds n_features_in_ checks to impute module (scikit-learn#18580)
* ENH Adds n_features_in_ checks to impute module
1 parent e970678 commit 5fb02bb

File tree

4 files changed: 11 additions and 16 deletions

sklearn/impute/_base.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -793,16 +793,12 @@ def transform(self, X):
793793
# Need not validate X again as it would have already been validated
794794
# in the Imputer calling MissingIndicator
795795
if not self._precomputed:
796-
X = self._validate_input(X, in_fit=True)
796+
X = self._validate_input(X, in_fit=False)
797797
else:
798798
if not (hasattr(X, 'dtype') and X.dtype.kind == 'b'):
799799
raise ValueError("precomputed is True but the input data is "
800800
"not a mask")
801801

802-
if X.shape[1] != self._n_features:
803-
raise ValueError("X has a different number of features "
804-
"than during fitting.")
805-
806802
imputer_mask, features = self._get_missing_features_info(X)
807803

808804
if self.features == "missing-only":

sklearn/impute/_iterative.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -474,7 +474,7 @@ def _get_abs_corr_mat(self, X_filled, tolerance=1e-6):
474474
abs_corr_mat = normalize(abs_corr_mat, norm='l1', axis=0, copy=False)
475475
return abs_corr_mat
476476

477-
def _initial_imputation(self, X):
477+
def _initial_imputation(self, X, in_fit=False):
478478
"""Perform initial imputation for input X.
479479
480480
Parameters
@@ -483,6 +483,9 @@ def _initial_imputation(self, X):
483483
Input data, where "n_samples" is the number of samples and
484484
"n_features" is the number of features.
485485
486+
in_fit : bool, default=False
487+
Whether function is called in fit.
488+
486489
Returns
487490
-------
488491
Xt : ndarray, shape (n_samples, n_features)
@@ -506,7 +509,7 @@ def _initial_imputation(self, X):
506509
else:
507510
force_all_finite = True
508511

509-
X = self._validate_data(X, dtype=FLOAT_DTYPES, order="F",
512+
X = self._validate_data(X, dtype=FLOAT_DTYPES, order="F", reset=in_fit,
510513
force_all_finite=force_all_finite)
511514
_check_inputs_dtype(X, self.missing_values)
512515

@@ -600,7 +603,8 @@ def fit_transform(self, X, y=None):
600603

601604
self.initial_imputer_ = None
602605

603-
X, Xt, mask_missing_values, complete_mask = self._initial_imputation(X)
606+
X, Xt, mask_missing_values, complete_mask = (
607+
self._initial_imputation(X, in_fit=True))
604608

605609
super()._fit_indicator(complete_mask)
606610
X_indicator = super()._transform_indicator(complete_mask)

sklearn/impute/_knn.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
from ..metrics.pairwise import _NAN_METRICS
1111
from ..neighbors._base import _get_weights
1212
from ..neighbors._base import _check_weights
13-
from ..utils import check_array
1413
from ..utils import is_scalar_nan
1514
from ..utils._mask import _get_mask
1615
from ..utils.validation import check_is_fitted
@@ -213,12 +212,9 @@ def transform(self, X):
213212
force_all_finite = True
214213
else:
215214
force_all_finite = "allow-nan"
216-
X = check_array(X, accept_sparse=False, dtype=FLOAT_DTYPES,
217-
force_all_finite=force_all_finite, copy=self.copy)
218-
219-
if X.shape[1] != self._fit_X.shape[1]:
220-
raise ValueError("Incompatible dimension between the fitted "
221-
"dataset and the one to be transformed")
215+
X = self._validate_data(X, accept_sparse=False, dtype=FLOAT_DTYPES,
216+
force_all_finite=force_all_finite,
217+
copy=self.copy, reset=False)
222218

223219
mask = _get_mask(X, self.missing_values)
224220
mask_fit_X = self._mask_fit_X

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -344,7 +344,6 @@ def test_search_cv(estimator, check, request):
344344
'feature_extraction',
345345
'feature_selection',
346346
'gaussian_process',
347-
'impute',
348347
'isotonic',
349348
'linear_model',
350349
'manifold',

0 commit comments

Comments
 (0)