FunctionTransformer: do not reset `n_features_in_` in `inverse_transform`.

Summary of the hunks below:
- `_check_input` gains a keyword-only `reset` flag that is forwarded to
  `_validate_data`; `fit` passes `reset=True`, `transform` passes `reset=False`.
- `_transform` no longer validates; `inverse_transform` validates its input
  with `check_array` when `validate=True`, so fitted attributes are untouched.
- The class docstring documents `n_features_in_` and `feature_names_in_`.
- A changelog entry and a regression test are added.

diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index d9dcf8757bc68..ce9da1d4c989d 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -807,6 +807,9 @@ Changelog
   `n_features_in_` and will be removed in 1.2. :pr:`20240` by
   :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
+- |Fix| :class:`preprocessing.FunctionTransformer` does not set `n_features_in_`
+  based on the input to `inverse_transform`. :pr:`20961` by `Thomas Fan`_.
+
 :mod:`sklearn.svm`
 ...................
 
diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py
index 20ee90f5f253f..d975f63e32fe2 100644
--- a/sklearn/preprocessing/_function_transformer.py
+++ b/sklearn/preprocessing/_function_transformer.py
@@ -1,7 +1,7 @@
 import warnings
 
 from ..base import BaseEstimator, TransformerMixin
-from ..utils.validation import _allclose_dense_sparse
+from ..utils.validation import _allclose_dense_sparse, check_array
 
 
 def _identity(X):
@@ -71,6 +71,20 @@ class FunctionTransformer(TransformerMixin, BaseEstimator):
 
         .. versionadded:: 0.18
 
+    Attributes
+    ----------
+    n_features_in_ : int
+        Number of features seen during :term:`fit`. Defined only when
+        `validate=True`.
+
+        .. versionadded:: 0.24
+
+    feature_names_in_ : ndarray of shape (`n_features_in_`,)
+        Names of features seen during :term:`fit`. Defined only when `validate=True`
+        and `X` has feature names that are all strings.
+
+        .. versionadded:: 1.0
+
     See Also
     --------
     MaxAbsScaler : Scale each feature by its maximum absolute value.
@@ -110,9 +124,9 @@ def __init__(
         self.kw_args = kw_args
         self.inv_kw_args = inv_kw_args
 
-    def _check_input(self, X):
+    def _check_input(self, X, *, reset):
         if self.validate:
-            return self._validate_data(X, accept_sparse=self.accept_sparse)
+            return self._validate_data(X, accept_sparse=self.accept_sparse, reset=reset)
         return X
 
     def _check_inverse_transform(self, X):
@@ -146,7 +160,7 @@ def fit(self, X, y=None):
         self : object
             FunctionTransformer class instance.
         """
-        X = self._check_input(X)
+        X = self._check_input(X, reset=True)
         if self.check_inverse and not (self.func is None or self.inverse_func is None):
             self._check_inverse_transform(X)
         return self
@@ -164,6 +178,7 @@ def transform(self, X):
         X_out : array-like, shape (n_samples, n_features)
             Transformed input.
         """
+        X = self._check_input(X, reset=False)
         return self._transform(X, func=self.func, kw_args=self.kw_args)
 
     def inverse_transform(self, X):
@@ -179,11 +194,11 @@ def inverse_transform(self, X):
         X_out : array-like, shape (n_samples, n_features)
             Transformed input.
         """
+        if self.validate:
+            X = check_array(X, accept_sparse=self.accept_sparse)
         return self._transform(X, func=self.inverse_func, kw_args=self.inv_kw_args)
 
     def _transform(self, X, func=None, kw_args=None):
-        X = self._check_input(X)
-
         if func is None:
             func = _identity
 
diff --git a/sklearn/preprocessing/tests/test_function_transformer.py b/sklearn/preprocessing/tests/test_function_transformer.py
index b3e517ac0c36c..b1ba9ebe6b762 100644
--- a/sklearn/preprocessing/tests/test_function_transformer.py
+++ b/sklearn/preprocessing/tests/test_function_transformer.py
@@ -174,3 +174,27 @@ def test_function_transformer_frame():
     transformer = FunctionTransformer()
     X_df_trans = transformer.fit_transform(X_df)
     assert hasattr(X_df_trans, "loc")
+
+
+def test_function_transformer_validate_inverse():
+    """Test that function transformer does not reset estimator in
+    `inverse_transform`."""
+
+    def add_constant_feature(X):
+        X_one = np.ones((X.shape[0], 1))
+        return np.concatenate((X, X_one), axis=1)
+
+    def inverse_add_constant(X):
+        return X[:, :-1]
+
+    X = np.array([[1, 2], [3, 4], [3, 4]])
+    trans = FunctionTransformer(
+        func=add_constant_feature,
+        inverse_func=inverse_add_constant,
+        validate=True,
+    )
+    X_trans = trans.fit_transform(X)
+    assert trans.n_features_in_ == X.shape[1]
+
+    trans.inverse_transform(X_trans)
+    assert trans.n_features_in_ == X.shape[1]