diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index b4fcbee992383..0db9f11adab7c 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -134,6 +134,12 @@ Changelog
 - |API| `np.matrix` usage is deprecated in 1.0 and will raise a `TypeError` in
   1.2. :pr:`20165` by `Thomas Fan`_.
 
+- |API| All estimators store `feature_names_in_` when fitted on pandas
+  DataFrames. These feature names are compared to the names seen in non-`fit`
+  methods, e.g. `transform`, and a `FutureWarning` is raised if they are not
+  consistent. These `FutureWarning`s will become `ValueError`s in 1.2.
+  :pr:`18010` by `Thomas Fan`_.
+
 :mod:`sklearn.base`
 ...................
 
diff --git a/sklearn/base.py b/sklearn/base.py
index 6730ea8fd4590..a585b2b06c394 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -24,6 +24,7 @@
 from .utils.validation import _check_y
 from .utils.validation import _num_features
 from .utils._estimator_html_repr import estimator_html_repr
+from .utils.validation import _get_feature_names
 
 
 def clone(estimator, *, safe=True):
@@ -395,6 +396,92 @@ def _check_n_features(self, X, reset):
                 f"is expecting {self.n_features_in_} features as input."
             )
 
+    def _check_feature_names(self, X, *, reset):
+        """Set or check the `feature_names_in_` attribute.
+
+        .. versionadded:: 1.0
+
+        Parameters
+        ----------
+        X : {ndarray, dataframe} of shape (n_samples, n_features)
+            The input samples.
+
+        reset : bool
+            Whether to reset the `feature_names_in_` attribute.
+            If False, the input will be checked for consistency with
+            feature names of data provided when reset was last True.
+            .. note::
+               It is recommended to call `reset=True` in `fit` and in the first
+               call to `partial_fit`. All other methods that validate `X`
+               should set `reset=False`.
+        """
+
+        if reset:
+            feature_names_in = _get_feature_names(X)
+            if feature_names_in is not None:
+                self.feature_names_in_ = feature_names_in
+            return
+
+        fitted_feature_names = getattr(self, "feature_names_in_", None)
+        X_feature_names = _get_feature_names(X)
+
+        if fitted_feature_names is None and X_feature_names is None:
+            # no feature names seen in fit and in X
+            return
+
+        if X_feature_names is not None and fitted_feature_names is None:
+            warnings.warn(
+                f"X has feature names, but {self.__class__.__name__} was fitted without"
+                " feature names"
+            )
+            return
+
+        if X_feature_names is None and fitted_feature_names is not None:
+            warnings.warn(
+                "X does not have valid feature names, but"
+                f" {self.__class__.__name__} was fitted with feature names"
+            )
+            return
+
+        # validate the feature names against the `feature_names_in_` attribute
+        if len(fitted_feature_names) != len(X_feature_names) or np.any(
+            fitted_feature_names != X_feature_names
+        ):
+            message = (
+                "The feature names should match those that were "
+                "passed during fit. Starting version 1.2, an error will be raised.\n"
+            )
+            fitted_feature_names_set = set(fitted_feature_names)
+            X_feature_names_set = set(X_feature_names)
+
+            unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
+            missing_names = sorted(fitted_feature_names_set - X_feature_names_set)
+
+            def add_names(names):
+                output = ""
+                max_n_names = 5
+                for i, name in enumerate(names):
+                    if i >= max_n_names:
+                        output += "- ...\n"
+                        break
+                    output += f"- {name}\n"
+                return output
+
+            if unexpected_names:
+                message += "Feature names unseen at fit time:\n"
+                message += add_names(unexpected_names)
+
+            if missing_names:
+                message += "Feature names seen at fit time, yet now missing:\n"
+                message += add_names(missing_names)
+
+            if not missing_names and not unexpected_names:
+                message += (
+                    "Feature names must be in the same order as they were in fit.\n"
+                )
+
+            warnings.warn(message, FutureWarning)
+
     def _validate_data(
         self,
         X="no_validation",
@@ -452,6 +539,8 @@ def _validate_data(
             The validated input. A tuple is returned if both `X` and `y` are
             validated.
         """
+        self._check_feature_names(X, reset=reset)
+
         if y is None and self._get_tags()["requires_y"]:
             raise ValueError(
                 f"This {self.__class__.__name__} estimator "
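Note (illustration only, not part of the patch): a minimal sketch of how the new consistency check behaves from the user's side, assuming this patch is applied and using `StandardScaler` as a stand-in for any estimator that routes its validation through `_validate_data`; the column names and data are made up for the example.

```python
import pandas as pd
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
scaler = StandardScaler().fit(df)

# Fitting on a DataFrame with string columns records the column names.
print(scaler.feature_names_in_)  # ['a' 'b']

# Reordered columns at transform time emit a FutureWarning
# ("The feature names should match those that were passed during fit. ...");
# the changelog entry above plans to turn this into a ValueError in 1.2.
scaler.transform(df[["b", "a"]])

# Transforming a plain ndarray after fitting on a DataFrame warns that
# "X does not have valid feature names".
scaler.transform(df.to_numpy())
```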
diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index 12d643f6e21dc..add834332dc43 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -368,6 +368,8 @@ def fit(self, X, y, sample_weight=None):
             first_clf = self.calibrated_classifiers_[0].base_estimator
             if hasattr(first_clf, "n_features_in_"):
                 self.n_features_in_ = first_clf.n_features_in_
+            if hasattr(first_clf, "feature_names_in_"):
+                self.feature_names_in_ = first_clf.feature_names_in_
         return self
 
     def predict_proba(self, X):
diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py
index 8b2a63dca380d..2b15376243c66 100644
--- a/sklearn/feature_selection/_from_model.py
+++ b/sklearn/feature_selection/_from_model.py
@@ -257,6 +257,8 @@ def fit(self, X, y=None, **fit_params):
             raise NotFittedError("Since 'prefit=True', call transform directly")
         self.estimator_ = clone(self.estimator)
         self.estimator_.fit(X, y, **fit_params)
+        if hasattr(self.estimator_, "feature_names_in_"):
+            self.feature_names_in_ = self.estimator_.feature_names_in_
         return self
 
     @property
diff --git a/sklearn/kernel_approximation.py b/sklearn/kernel_approximation.py
index dc68b2d773611..40f451d882c2a 100644
--- a/sklearn/kernel_approximation.py
+++ b/sklearn/kernel_approximation.py
@@ -21,7 +21,7 @@
 
 from .base import BaseEstimator
 from .base import TransformerMixin
-from .utils import check_random_state, as_float_array
+from .utils import check_random_state
 from .utils.extmath import safe_sparse_dot
 from .utils.validation import check_is_fitted
 from .metrics.pairwise import pairwise_kernels, KERNEL_PARAMS
@@ -450,9 +450,9 @@ def transform(self, X):
             Projected array.
         """
         check_is_fitted(self)
-
-        X = as_float_array(X, copy=True)
-        X = self._validate_data(X, copy=False, reset=False)
+        X = self._validate_data(
+            X, copy=True, dtype=[np.float64, np.float32], reset=False
+        )
 
         if (X <= -self.skewedness).any():
             raise ValueError("X may not contain entries smaller than -skewedness.")
diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py
index 71d2ca291fcbe..88050bfce383e 100644
--- a/sklearn/linear_model/_ransac.py
+++ b/sklearn/linear_model/_ransac.py
@@ -539,6 +539,7 @@ def predict(self, X):
             Returns predicted values.
""" check_is_fitted(self) + self._check_feature_names(X, reset=False) return self.estimator_.predict(X) @@ -561,6 +562,7 @@ def score(self, X, y): Score of the prediction. """ check_is_fitted(self) + self._check_feature_names(X, reset=False) return self.estimator_.score(X, y) diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index 67d9a47881953..ee9c14501400f 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -1983,6 +1983,8 @@ def fit(self, X, y, sample_weight=None): self.coef_ = estimator.coef_ self.intercept_ = estimator.intercept_ self.n_features_in_ = estimator.n_features_in_ + if hasattr(estimator, "feature_names_in_"): + self.feature_names_in_ = estimator.feature_names_in_ return self diff --git a/sklearn/manifold/_isomap.py b/sklearn/manifold/_isomap.py index 920a0a5503326..52a0e40c72a2c 100644 --- a/sklearn/manifold/_isomap.py +++ b/sklearn/manifold/_isomap.py @@ -172,6 +172,8 @@ def _fit_transform(self, X): ) self.nbrs_.fit(X) self.n_features_in_ = self.nbrs_.n_features_in_ + if hasattr(self.nbrs_, "feature_names_in_"): + self.feature_names_in_ = self.nbrs_.feature_names_in_ self.kernel_pca_ = KernelPCA( n_components=self.n_components, diff --git a/sklearn/manifold/_locally_linear.py b/sklearn/manifold/_locally_linear.py index 1f3f6680dc773..7c80340191877 100644 --- a/sklearn/manifold/_locally_linear.py +++ b/sklearn/manifold/_locally_linear.py @@ -768,7 +768,7 @@ def transform(self, X): """ check_is_fitted(self) - X = check_array(X) + X = self._validate_data(X, reset=False) ind = self.nbrs_.kneighbors( X, n_neighbors=self.n_neighbors, return_distance=False ) diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index 3d8647e3960f6..6d9afeb10de63 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -15,7 +15,6 @@ from ..base import BaseEstimator from ..base import TransformerMixin -from ..utils import check_array from ..utils import check_random_state from ..utils import gen_even_slices from ..utils.extmath import safe_sparse_dot @@ -333,7 +332,7 @@ def score_samples(self, X): """ check_is_fitted(self) - v = check_array(X, accept_sparse="csr") + v = self._validate_data(X, accept_sparse="csr", reset=False) rng = check_random_state(self.random_state) # Randomly corrupt one feature in each sample in v. 
diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py index f94a2fc86c81d..36dc6064e4a46 100644 --- a/sklearn/tests/test_base.py +++ b/sklearn/tests/test_base.py @@ -1,6 +1,7 @@ # Author: Gael Varoquaux # License: BSD 3 clause +import re import numpy as np import scipy.sparse as sp import pytest @@ -615,3 +616,73 @@ def test_n_features_in_no_validation(): # does not raise est._check_n_features("invalid X", reset=False) + + +def test_feature_names_in(): + """Check that feature_name_in are recorded by `_validate_data`""" + pd = pytest.importorskip("pandas") + iris = datasets.load_iris() + X_np = iris.data + df = pd.DataFrame(X_np, columns=iris.feature_names) + + class NoOpTransformer(TransformerMixin, BaseEstimator): + def fit(self, X, y=None): + self._validate_data(X) + return self + + def transform(self, X): + self._validate_data(X, reset=False) + return X + + # fit on dataframe saves the feature names + trans = NoOpTransformer().fit(df) + assert_array_equal(trans.feature_names_in_, df.columns) + + msg = "The feature names should match those that were passed" + df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1]) + with pytest.warns(FutureWarning, match=msg): + trans.transform(df_bad) + + # warns when fitted on dataframe and transforming a ndarray + msg = ( + "X does not have valid feature names, but NoOpTransformer was " + "fitted with feature names" + ) + with pytest.warns(UserWarning, match=msg): + trans.transform(X_np) + + # warns when fitted on a ndarray and transforming dataframe + msg = "X has feature names, but NoOpTransformer was fitted without feature names" + trans = NoOpTransformer().fit(X_np) + with pytest.warns(UserWarning, match=msg): + trans.transform(df) + + # fit on dataframe with all integer feature names works without warning + df_int_names = pd.DataFrame(X_np) + trans = NoOpTransformer() + with pytest.warns(None) as record: + trans.fit(df_int_names) + assert not record + + # fit on dataframe with no feature names or all integer feature names + # -> do not warn on trainsform + Xs = [X_np, df_int_names] + for X in Xs: + with pytest.warns(None) as record: + trans.transform(X) + assert not record + + # TODO: Convert to a error in 1.2 + # fit on dataframe with feature names that are mixed warns: + df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2]) + trans = NoOpTransformer() + msg = re.escape( + "Feature names only support names that are all strings. " + "Got feature names with dtypes: ['int', 'str']" + ) + with pytest.warns(FutureWarning, match=msg) as record: + trans.fit(df_mixed) + + # transform on feature names that are mixed also warns: + with pytest.warns(FutureWarning, match=msg) as record: + trans.transform(df_mixed) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 1d6700cf46ded..008bdee7e646b 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -47,6 +47,7 @@ _get_check_estimator_ids, check_class_weight_balanced_linear_classifier, parametrize_with_checks, + check_dataframe_column_names_consistency, check_n_features_in_after_fitting, ) @@ -313,3 +314,41 @@ def test_search_cv(estimator, check, request): def test_check_n_features_in_after_fitting(estimator): _set_checking_parameters(estimator) check_n_features_in_after_fitting(estimator.__class__.__name__, estimator) + + +# TODO: When more modules get added, we can remove it from this list to make +# sure it gets tested. After we finish each module we can move the checks +# into check_estimator. 
+# NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that +# delegates validation to a base estimator, the check is testing that the base estimator +# is checking for column name consistency. + +COLUMN_NAME_MODULES_TO_IGNORE = { + "compose", + "ensemble", + "feature_extraction", + "kernel_approximation", + "model_selection", + "multiclass", + "multioutput", + "pipeline", + "semi_supervised", +} + + +column_name_estimators = [ + est + for est in _tested_estimators() + if est.__module__.split(".")[1] not in COLUMN_NAME_MODULES_TO_IGNORE +] + + +@pytest.mark.parametrize( + "estimator", column_name_estimators, ids=_get_check_estimator_ids +) +def test_pandas_column_name_consistency(estimator): + _set_checking_parameters(estimator) + with ignore_warnings(category=(FutureWarning)): + check_dataframe_column_names_consistency( + estimator.__class__.__name__, estimator + ) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 7749484ea5b22..3a79584eac7f7 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3648,3 +3648,119 @@ def check_estimator_get_tags_default_keys(name, estimator_orig): f"{name}._get_tags() is missing entries for the following default tags" f": {default_tags_keys - tags_keys.intersection(default_tags_keys)}" ) + + +def check_dataframe_column_names_consistency(name, estimator_orig): + try: + import pandas as pd + except ImportError: + raise SkipTest( + "pandas is not installed: not checking column name consistency for pandas" + ) + + tags = _safe_tags(estimator_orig) + + if ( + "2darray" not in tags["X_types"] + and "sparse" not in tags["X_types"] + or tags["no_validation"] + ): + return + + rng = np.random.RandomState(0) + + estimator = clone(estimator_orig) + set_random_state(estimator) + + X_orig = rng.normal(size=(150, 8)) + X_orig = _enforce_estimator_tags_x(estimator, X_orig) + X_orig = _pairwise_estimator_convert_X(X_orig, estimator) + n_samples, n_features = X_orig.shape + + names = np.array([f"col_{i}" for i in range(n_features)]) + X = pd.DataFrame(X_orig, columns=names) + + if is_regressor(estimator): + y = rng.normal(size=n_samples) + else: + y = rng.randint(low=0, high=2, size=n_samples) + y = _enforce_estimator_tags_y(estimator, y) + estimator.fit(X, y) + + if not hasattr(estimator, "feature_names_in_"): + raise ValueError( + "Estimator does not have a feature_names_in_ " + "attribute after fitting with a dataframe" + ) + assert_array_equal(estimator.feature_names_in_, names) + + check_methods = [] + for method in ( + "predict", + "transform", + "decision_function", + "predict_proba", + "score", + "score_samples", + "predict_log_proba", + ): + if not hasattr(estimator, method): + continue + + callable_method = getattr(estimator, method) + if method == "score": + callable_method = partial(callable_method, y=y) + check_methods.append((method, callable_method)) + + for _, method in check_methods: + method(X) # works + + invalid_names = [ + (names[::-1], "Feature names must be in the same order as they were in fit."), + ( + [f"another_prefix_{i}" for i in range(n_features)], + "Feature names unseen at fit time:\n- another_prefix_0\n-" + " another_prefix_1\n", + ), + ( + names[:3], + f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n", + ), + ] + + for invalid_name, additional_message in invalid_names: + X_bad = pd.DataFrame(X, columns=invalid_name) + + expected_msg = re.escape( + "The feature names should match those that were passed " + "during 
fit. Starting version 1.2, an error will be raised.\n" + f"{additional_message}" + ) + for name, method in check_methods: + # TODO In 1.2, this will be an error. + with warnings.catch_warnings(): + warnings.filterwarnings( + "error", + category=FutureWarning, + module="sklearn", + ) + with raises( + FutureWarning, match=expected_msg, err_msg=f"{name} did not raise" + ): + method(X_bad) + + # partial_fit checks on second call + if not hasattr(estimator, "partial_fit"): + continue + + estimator = clone(estimator_orig) + if is_classifier(estimator): + classes = np.unique(y) + estimator.partial_fit(X, y, classes=classes) + else: + estimator.partial_fit(X, y) + + with warnings.catch_warnings(): + warnings.filterwarnings("error", category=FutureWarning, module="sklearn") + with raises(FutureWarning, match=expected_msg): + estimator.partial_fit(X_bad, y) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ea158234ea785..3d565ca5895ef 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -43,6 +43,7 @@ check_classifiers_multilabel_output_format_decision_function, check_classifiers_multilabel_output_format_predict, check_classifiers_multilabel_output_format_predict_proba, + check_dataframe_column_names_consistency, check_estimator, check_estimator_get_tags_default_keys, check_estimators_unfitted, @@ -411,6 +412,18 @@ def _more_tags(self): return {"poor_score": True} +class PartialFitChecksName(BaseEstimator): + def fit(self, X, y): + self._validate_data(X, y) + return self + + def partial_fit(self, X, y): + reset = not hasattr(self, "_fitted") + self._validate_data(X, y, reset=reset) + self._fitted = True + return self + + def test_not_an_array_array_function(): if np_version < parse_version("1.17"): raise SkipTest("array_function protocol not supported in numpy <1.17") @@ -697,6 +710,13 @@ def test_check_estimator_get_tags_default_keys(): check_estimator_get_tags_default_keys(estimator.__class__.__name__, estimator) +def test_check_dataframe_column_names_consistency(): + err_msg = "Estimator does not have a feature_names_in_" + with raises(ValueError, match=err_msg): + check_dataframe_column_names_consistency("estimator_name", BaseBadClassifier()) + check_dataframe_column_names_consistency("estimator_name", PartialFitChecksName()) + + class _BaseMultiLabelClassifierMock(ClassifierMixin, BaseEstimator): def __init__(self, response_output): self.response_output = response_output diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 1a1449ecc209f..9d88a06149e61 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -49,6 +49,7 @@ _allclose_dense_sparse, _num_features, FLOAT_DTYPES, + _get_feature_names, ) from sklearn.utils.validation import _check_fit_params @@ -1445,3 +1446,59 @@ def test_check_array_deprecated_matrix(): ) with pytest.warns(FutureWarning, match=msg): check_array(X) + + +@pytest.mark.parametrize( + "names", + [list(range(2)), range(2), None], + ids=["list-int", "range", "default"], +) +def test_get_feature_names_pandas_with_ints_no_warning(names): + """Get feature names with pandas dataframes with ints without warning""" + pd = pytest.importorskip("pandas") + X = pd.DataFrame([[1, 2], [4, 5], [5, 6]], columns=names) + + with pytest.warns(None) as record: + names = _get_feature_names(X) + assert not record + assert names is None + + +def test_get_feature_names_pandas(): + """Get 
feature names with pandas dataframes.""" + pd = pytest.importorskip("pandas") + columns = [f"col_{i}" for i in range(3)] + X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=columns) + feature_names = _get_feature_names(X) + + assert_array_equal(feature_names, columns) + + +def test_get_feature_names_numpy(): + """Get feature names return None for numpy arrays.""" + X = np.array([[1, 2, 3], [4, 5, 6]]) + names = _get_feature_names(X) + assert names is None + + +# TODO: Convert to a error in 1.2 +@pytest.mark.parametrize( + "names, dtypes", + [ + ([["a", "b"], ["c", "d"]], "['tuple']"), + (["a", 1], "['int', 'str']"), + ], + ids=["multi-index", "mixed"], +) +def test_get_feature_names_invalid_dtypes_warns(names, dtypes): + """Get feature names warns when the feature names have mixed dtypes""" + pd = pytest.importorskip("pandas") + X = pd.DataFrame([[1, 2], [4, 5], [5, 6]], columns=names) + + msg = re.escape( + "Feature names only support names that are all strings. " + f"Got feature names with dtypes: {dtypes}. An error will be raised" + ) + with pytest.warns(FutureWarning, match=msg): + names = _get_feature_names(X) + assert names is None diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 98bf6ac8bdb6a..b5b485c5837ab 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1587,3 +1587,51 @@ def _check_fit_params(X, fit_params, indices=None): ) return fit_params_validated + + +def _get_feature_names(X): + """Get feature names from X. + + Support for other array containers should place its implementation here. + + Parameters + ---------- + X : {ndarray, dataframe} of shape (n_samples, n_features) + Array container to extract feature names. + + - pandas dataframe : The columns will be considered to be feature + names. If the dataframe contains non-string feature names, `None` is + returned. + - All other array containers will return `None`. + + Returns + ------- + names: ndarray or None + Feature names of `X`. Unrecognized array containers will return `None`. + """ + feature_names = None + + # extract feature names for support array containers + if hasattr(X, "columns"): + feature_names = np.asarray(X.columns) + + if feature_names is None or len(feature_names) == 0: + return + + types = sorted(t.__qualname__ for t in set(type(v) for v in feature_names)) + + # Warn when types are mixed. + # ints and strings do not warn + if len(types) > 1 or not (types[0].startswith("int") or types[0] == "str"): + # TODO: Convert to an error in 1.2 + warnings.warn( + "Feature names only support names that are all strings. " + f"Got feature names with dtypes: {types}. An error will be raised " + "in 1.2.", + FutureWarning, + ) + return + + # Only feature names of all strings are supported + if types[0] == "str": + return feature_names
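Note (illustration only, not part of the patch): the semantics of the new helper `_get_feature_names`, as exercised by the tests above. `_get_feature_names` is a private utility rather than public API, so this snippet is for review purposes only and assumes the patch is applied.

```python
import numpy as np
import pandas as pd
from sklearn.utils.validation import _get_feature_names

X = np.array([[1, 2], [3, 4]])

# String column names are returned as an array of names.
print(_get_feature_names(pd.DataFrame(X, columns=["a", "b"])))  # ['a' 'b']

# A default RangeIndex (all-integer names) and plain ndarrays return None.
print(_get_feature_names(pd.DataFrame(X)))  # None
print(_get_feature_names(X))                # None

# Mixed string/int names emit a FutureWarning and return None for now;
# the TODO in the patch plans to turn this into an error in 1.2.
print(_get_feature_names(pd.DataFrame(X, columns=["a", 1])))  # None + FutureWarning
```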