Merged
71 commits
779385f
ENH Adds column name consistency
thomasjpfan Jul 27, 2020
a579392
BUG Fix
thomasjpfan Jul 27, 2020
74368fd
BUG Fix
thomasjpfan Jul 27, 2020
bce5d0f
STY Flake8
thomasjpfan Jul 27, 2020
2ca4dbf
MNT Adds xarray
thomasjpfan Jul 27, 2020
f6048d7
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Sep 1, 2020
485b5ca
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Sep 2, 2020
df6d193
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Sep 3, 2020
7465ec2
CLN Smaller diff
thomasjpfan Sep 4, 2020
4d5c3d4
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Sep 4, 2020
19583bc
CLN Smaller diff
thomasjpfan Sep 4, 2020
cb3e6be
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Sep 4, 2020
4d0840a
TST Adds test for feature_names_in
thomasjpfan Sep 4, 2020
53270fe
ENH Adds tests for coverage
thomasjpfan Sep 4, 2020
4f7c5e2
TST Fixes warning message
thomasjpfan Sep 4, 2020
37117eb
ENH Adds xarray
thomasjpfan Sep 4, 2020
5ed789b
BLD Force build on ci
thomasjpfan Sep 4, 2020
ee09732
Merge branch 'master' into column_name_consistent
ogrisel Sep 28, 2020
7564e75
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Sep 30, 2020
df146d1
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Oct 1, 2020
1b81d12
CLN Changes naming
thomasjpfan Oct 1, 2020
3522f37
TST Adds more testing
thomasjpfan Oct 1, 2020
a81e2a3
WIP Address comments
thomasjpfan Oct 1, 2020
2c45b65
ENH Only update cross_decompositon
thomasjpfan Oct 1, 2020
f43356b
TST Adds test for partial_fit
thomasjpfan Oct 1, 2020
8efb395
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Oct 6, 2020
46f332d
TST Fixes
thomasjpfan Oct 6, 2020
c93bd9d
DOC More cleanup
thomasjpfan Oct 6, 2020
5039f5a
DOC Adds docstring
thomasjpfan Oct 6, 2020
ee03ab7
TST Adds smoke test
thomasjpfan Oct 6, 2020
be480b6
ENH Restricts column names to string
thomasjpfan Oct 6, 2020
8c5f425
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Oct 6, 2020
3dd4041
Merge remote-tracking branch 'upstream/master' into column_name_consi…
thomasjpfan Oct 7, 2020
670996a
CLN Removes reset
thomasjpfan Oct 7, 2020
a4af9c0
CLN Reduce diffs
thomasjpfan Oct 7, 2020
afae3fc
Merge branch 'master' into column_name_consistent
ogrisel Oct 14, 2020
f57dcb0
Merge remote-tracking branch 'upstream/main' into column_name_consistent
thomasjpfan Jun 28, 2021
f853336
STY Runs black
thomasjpfan Jun 28, 2021
aeb220e
CLN Slightly nicer tests
thomasjpfan Jun 28, 2021
a2ce8b2
CLN Does not define feature_names_in_ when they do not exist
thomasjpfan Jun 28, 2021
f344353
CLN Be more strict about dataframes
thomasjpfan Jun 28, 2021
f114e98
CLN Sync with data generation in check_n_features_in_after_fitting
thomasjpfan Jun 28, 2021
396d3ea
ENH Removes estimators from ignore list
thomasjpfan Jun 28, 2021
8372a6e
ENH Removes estimators from ignore list
thomasjpfan Jun 28, 2021
dd36120
DOC Adds whats new
thomasjpfan Jun 28, 2021
bacbec1
BUG Fix typo bug
thomasjpfan Jun 28, 2021
d86e70e
CLN Address comments
thomasjpfan Jul 9, 2021
8a4212f
CLN Address comments
thomasjpfan Jul 9, 2021
ee95642
DOC grammer
thomasjpfan Jul 9, 2021
7aeae36
TST Adds test and remove support for multiindex
thomasjpfan Jul 10, 2021
4c79f62
Merge remote-tracking branch 'upstream/main' into column_name_consistent
thomasjpfan Jul 10, 2021
2e4e422
STY Run black
thomasjpfan Jul 10, 2021
d112c61
ENH Adds warnings to _check_feature_names
thomasjpfan Jul 10, 2021
59a1cc1
Merge remote-tracking branch 'upstream/main' into column_name_consistent
thomasjpfan Jul 17, 2021
9494a1f
Merge remote-tracking branch 'upstream/main' into column_name_consistent
thomasjpfan Jul 23, 2021
f70a56e
CLN Address comments
thomasjpfan Jul 23, 2021
62bb28b
ENH Adds better errors
thomasjpfan Jul 23, 2021
a0d4d12
ENH Warn for non-string columns
thomasjpfan Jul 24, 2021
7bc3d8b
CLN Move back to validation
thomasjpfan Jul 24, 2021
6e42e0f
TST Fix test error message
thomasjpfan Jul 25, 2021
a171666
ENH Delegates validation to base estimator in Ransac
thomasjpfan Jul 25, 2021
681d045
Revert "ENH Delegates validation to base estimator in Ransac"
thomasjpfan Jul 26, 2021
673811a
ENH Only warn in fit
thomasjpfan Jul 26, 2021
1a2bf25
TST Makes test pass
thomasjpfan Aug 1, 2021
ef3657c
Merge remote-tracking branch 'upstream/main' into column_name_consistent
thomasjpfan Aug 6, 2021
7b1999e
Merge branch 'main' into column_name_consistent
ogrisel Aug 6, 2021
966f50f
Merge remote-tracking branch 'upstream/main' into column_name_consistent
thomasjpfan Aug 13, 2021
decb967
CLN Address comments
thomasjpfan Aug 13, 2021
5e89f6c
ENH Include 1.2 error in warning
thomasjpfan Aug 13, 2021
77221fd
CLN Better checking
thomasjpfan Aug 13, 2021
19de717
ENH Super restrictive on supporting ints and strs
thomasjpfan Aug 13, 2021
6 changes: 6 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -134,6 +134,12 @@ Changelog
- |API| `np.matrix` usage is deprecated in 1.0 and will raise a `TypeError` in
1.2. :pr:`20165` by `Thomas Fan`_.

- |API| All estimators store `feature_names_in_` when fitted on pandas DataFrames.
These feature names are compared to the names seen in non-`fit` methods such as
`transform`, and a `FutureWarning` is raised if they are not consistent.
These `FutureWarning`s will become `ValueError`s in 1.2.
:pr:`18010` by `Thomas Fan`_.
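
For context (not part of the changelog diff), a minimal sketch of the behaviour described above, assuming a scikit-learn build that includes this PR and using `StandardScaler` as an example:

import pandas as pd
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
scaler = StandardScaler().fit(df)  # records feature_names_in_ as ["a", "b"]

# Transforming with the columns reordered is inconsistent with fit and
# triggers a FutureWarning (a ValueError starting in 1.2).
scaler.transform(df[["b", "a"]])

# Transforming a plain ndarray after fitting on a DataFrame triggers a
# UserWarning instead, since there are no column names to compare.
scaler.transform(df.to_numpy())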

:mod:`sklearn.base`
...................

89 changes: 89 additions & 0 deletions sklearn/base.py
@@ -24,6 +24,7 @@
from .utils.validation import _check_y
from .utils.validation import _num_features
from .utils._estimator_html_repr import estimator_html_repr
from .utils.validation import _get_feature_names


def clone(estimator, *, safe=True):
@@ -395,6 +396,92 @@ def _check_n_features(self, X, reset):
f"is expecting {self.n_features_in_} features as input."
)

def _check_feature_names(self, X, *, reset):
"""Set or check the `feature_names_in_` attribute.

.. versionadded:: 1.0

Parameters
----------
X : {ndarray, dataframe} of shape (n_samples, n_features)
The input samples.

reset : bool
Whether to reset the `feature_names_in_` attribute.
If False, the input will be checked for consistency with
feature names of data provided when reset was last True.
.. note::
It is recommended to call `reset=True` in `fit` and in the first
call to `partial_fit`. All other methods that validate `X`
should set `reset=False`.
"""

if reset:
feature_names_in = _get_feature_names(X)
if feature_names_in is not None:
self.feature_names_in_ = feature_names_in
return

fitted_feature_names = getattr(self, "feature_names_in_", None)
X_feature_names = _get_feature_names(X)

if fitted_feature_names is None and X_feature_names is None:
# no feature names seen in fit and in X
return

if X_feature_names is not None and fitted_feature_names is None:
warnings.warn(
f"X has feature names, but {self.__class__.__name__} was fitted without"
" feature names"
)
return

if X_feature_names is None and fitted_feature_names is not None:
warnings.warn(
"X does not have valid feature names, but"
f" {self.__class__.__name__} was fitted with feature names"
)
return

# validate the feature names against the `feature_names_in_` attribute
if len(fitted_feature_names) != len(X_feature_names) or np.any(
fitted_feature_names != X_feature_names
):
message = (
"The feature names should match those that were "
"passed during fit. Starting version 1.2, an error will be raised.\n"
)
fitted_feature_names_set = set(fitted_feature_names)
X_feature_names_set = set(X_feature_names)

unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
missing_names = sorted(fitted_feature_names_set - X_feature_names_set)

def add_names(names):
output = ""
max_n_names = 5
for i, name in enumerate(names):
if i >= max_n_names:
output += "- ...\n"
break
output += f"- {name}\n"
return output

if unexpected_names:
message += "Feature names unseen at fit time:\n"
message += add_names(unexpected_names)

if missing_names:
message += "Feature names seen at fit time, yet now missing:\n"
message += add_names(missing_names)

if not missing_names and not missing_names:
message += (
"Feature names must be in the same order as they were in fit.\n"
)
Comment on lines +478 to +481

Contributor:
Is this a typo? I guess the intended line 478 is
if not missing_names and not unexpected_names:

thomasjpfan (Member, Author) commented Apr 9, 2022:
Yea, it's a typo. Are you interested in opening a PR to fix?

Contributor:
Sure! (It is difficult to say no at this point!)

Contributor:
Pull request created: #23091

warnings.warn(message, FutureWarning)
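
For illustration only (not part of the diff), a sketch of the warning this method builds for a column mismatch, assuming a scikit-learn build with this PR and using `MinMaxScaler` as an example estimator:

import pandas as pd
from sklearn.preprocessing import MinMaxScaler

X_fit = pd.DataFrame([[1.0, 2.0, 3.0]], columns=["a", "b", "c"])
X_new = pd.DataFrame([[1.0, 2.0, 3.0]], columns=["a", "b", "d"])

scaler = MinMaxScaler().fit(X_fit)
scaler.transform(X_new)
# FutureWarning: The feature names should match those that were passed during fit.
# Starting version 1.2, an error will be raised.
# Feature names unseen at fit time:
# - d
# Feature names seen at fit time, yet now missing:
# - c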

def _validate_data(
self,
X="no_validation",
@@ -452,6 +539,8 @@ def _validate_data(
The validated input. A tuple is returned if both `X` and `y` are
validated.
"""
self._check_feature_names(X, reset=reset)

if y is None and self._get_tags()["requires_y"]:
raise ValueError(
f"This {self.__class__.__name__} estimator "
2 changes: 2 additions & 0 deletions sklearn/calibration.py
@@ -368,6 +368,8 @@ def fit(self, X, y, sample_weight=None):
first_clf = self.calibrated_classifiers_[0].base_estimator
if hasattr(first_clf, "n_features_in_"):
self.n_features_in_ = first_clf.n_features_in_
if hasattr(first_clf, "feature_names_in_"):
self.feature_names_in_ = first_clf.feature_names_in_
return self

def predict_proba(self, X):
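
The change above illustrates a delegation pattern that recurs in several files in this diff (for example sklearn/feature_selection/_from_model.py, sklearn/linear_model/_ridge.py and sklearn/manifold/_isomap.py): when an inner estimator performs the validation, the outer estimator mirrors the inner estimator's `feature_names_in_` if it exists. A minimal, hypothetical sketch of the idea (not the PR's code):

from sklearn.base import BaseEstimator, clone

class DelegatingWrapper(BaseEstimator):
    """Hypothetical meta-estimator that defers all input validation to `estimator`."""

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y=None):
        self.estimator_ = clone(self.estimator).fit(X, y)
        # Mirror the attributes the inner estimator derived from X so that
        # feature-name checks see consistent metadata on the wrapper as well.
        self.n_features_in_ = self.estimator_.n_features_in_
        if hasattr(self.estimator_, "feature_names_in_"):
            self.feature_names_in_ = self.estimator_.feature_names_in_
        return self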
2 changes: 2 additions & 0 deletions sklearn/feature_selection/_from_model.py
@@ -257,6 +257,8 @@ def fit(self, X, y=None, **fit_params):
raise NotFittedError("Since 'prefit=True', call transform directly")
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X, y, **fit_params)
if hasattr(self.estimator_, "feature_names_in_"):
self.feature_names_in_ = self.estimator_.feature_names_in_
return self

@property
8 changes: 4 additions & 4 deletions sklearn/kernel_approximation.py
@@ -21,7 +21,7 @@

from .base import BaseEstimator
from .base import TransformerMixin
from .utils import check_random_state, as_float_array
from .utils import check_random_state
from .utils.extmath import safe_sparse_dot
from .utils.validation import check_is_fitted
from .metrics.pairwise import pairwise_kernels, KERNEL_PARAMS
@@ -450,9 +450,9 @@ def transform(self, X):
Projected array.
"""
check_is_fitted(self)

X = as_float_array(X, copy=True)
X = self._validate_data(X, copy=False, reset=False)
X = self._validate_data(
X, copy=True, dtype=[np.float64, np.float32], reset=False
)
if (X <= -self.skewedness).any():
raise ValueError("X may not contain entries smaller than -skewedness.")

2 changes: 2 additions & 0 deletions sklearn/linear_model/_ransac.py
@@ -539,6 +539,7 @@ def predict(self, X):
Returns predicted values.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)

return self.estimator_.predict(X)

@@ -561,6 +562,7 @@ def score(self, X, y):
Score of the prediction.
"""
check_is_fitted(self)
self._check_feature_names(X, reset=False)
Member:
Same comment here.

return self.estimator_.score(X, y)

2 changes: 2 additions & 0 deletions sklearn/linear_model/_ridge.py
@@ -1983,6 +1983,8 @@ def fit(self, X, y, sample_weight=None):
self.coef_ = estimator.coef_
self.intercept_ = estimator.intercept_
self.n_features_in_ = estimator.n_features_in_
if hasattr(estimator, "feature_names_in_"):
self.feature_names_in_ = estimator.feature_names_in_

return self

2 changes: 2 additions & 0 deletions sklearn/manifold/_isomap.py
@@ -172,6 +172,8 @@ def _fit_transform(self, X):
)
self.nbrs_.fit(X)
self.n_features_in_ = self.nbrs_.n_features_in_
if hasattr(self.nbrs_, "feature_names_in_"):
self.feature_names_in_ = self.nbrs_.feature_names_in_

self.kernel_pca_ = KernelPCA(
n_components=self.n_components,
2 changes: 1 addition & 1 deletion sklearn/manifold/_locally_linear.py
@@ -768,7 +768,7 @@ def transform(self, X):
"""
check_is_fitted(self)

X = check_array(X)
X = self._validate_data(X, reset=False)
ind = self.nbrs_.kneighbors(
X, n_neighbors=self.n_neighbors, return_distance=False
)
3 changes: 1 addition & 2 deletions sklearn/neural_network/_rbm.py
@@ -15,7 +15,6 @@

from ..base import BaseEstimator
from ..base import TransformerMixin
from ..utils import check_array
from ..utils import check_random_state
from ..utils import gen_even_slices
from ..utils.extmath import safe_sparse_dot
@@ -333,7 +332,7 @@ def score_samples(self, X):
"""
check_is_fitted(self)

v = check_array(X, accept_sparse="csr")
v = self._validate_data(X, accept_sparse="csr", reset=False)
rng = check_random_state(self.random_state)

# Randomly corrupt one feature in each sample in v.
71 changes: 71 additions & 0 deletions sklearn/tests/test_base.py
@@ -1,6 +1,7 @@
# Author: Gael Varoquaux
# License: BSD 3 clause

import re
import numpy as np
import scipy.sparse as sp
import pytest
@@ -615,3 +616,73 @@ def test_n_features_in_no_validation():

# does not raise
est._check_n_features("invalid X", reset=False)


def test_feature_names_in():
"""Check that feature_name_in are recorded by `_validate_data`"""
pd = pytest.importorskip("pandas")
iris = datasets.load_iris()
X_np = iris.data
df = pd.DataFrame(X_np, columns=iris.feature_names)

class NoOpTransformer(TransformerMixin, BaseEstimator):
def fit(self, X, y=None):
self._validate_data(X)
return self

def transform(self, X):
self._validate_data(X, reset=False)
return X

# fit on dataframe saves the feature names
trans = NoOpTransformer().fit(df)
assert_array_equal(trans.feature_names_in_, df.columns)

msg = "The feature names should match those that were passed"
df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
with pytest.warns(FutureWarning, match=msg):
trans.transform(df_bad)
Member:
If we expand the error message to be more descriptive about the missing feature names as suggested above, this part of the test needs to be expanded accordingly.

# warns when fitted on dataframe and transforming an ndarray
msg = (
"X does not have valid feature names, but NoOpTransformer was "
"fitted with feature names"
)
with pytest.warns(UserWarning, match=msg):
trans.transform(X_np)

# warns when fitted on an ndarray and transforming a dataframe
msg = "X has feature names, but NoOpTransformer was fitted without feature names"
trans = NoOpTransformer().fit(X_np)
with pytest.warns(UserWarning, match=msg):
trans.transform(df)

# fit on dataframe with all integer feature names works without warning
df_int_names = pd.DataFrame(X_np)
trans = NoOpTransformer()
with pytest.warns(None) as record:
trans.fit(df_int_names)
assert not record

# fit on dataframe with no feature names or all integer feature names
# -> do not warn on transform
Xs = [X_np, df_int_names]
for X in Xs:
with pytest.warns(None) as record:
trans.transform(X)
assert not record

# TODO: Convert to an error in 1.2
# fit on dataframe with feature names that are mixed warns:
df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2])
trans = NoOpTransformer()
msg = re.escape(
"Feature names only support names that are all strings. "
"Got feature names with dtypes: ['int', 'str']"
)
with pytest.warns(FutureWarning, match=msg) as record:
trans.fit(df_mixed)

# transform on feature names that are mixed also warns:
with pytest.warns(FutureWarning, match=msg) as record:
trans.transform(df_mixed)
39 changes: 39 additions & 0 deletions sklearn/tests/test_common.py
@@ -47,6 +47,7 @@
_get_check_estimator_ids,
check_class_weight_balanced_linear_classifier,
parametrize_with_checks,
check_dataframe_column_names_consistency,
check_n_features_in_after_fitting,
)

@@ -313,3 +314,41 @@ def test_search_cv(estimator, check, request):
def test_check_n_features_in_after_fitting(estimator):
_set_checking_parameters(estimator)
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)


# TODO: As more modules add support for column name consistency, remove them
# from this list to make sure they get tested. After we finish each module we
# can move the checks into check_estimator.
# NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that
# delegates validation to a base estimator, the check verifies that the base estimator
# checks for column name consistency.

COLUMN_NAME_MODULES_TO_IGNORE = {
"compose",
"ensemble",
"feature_extraction",
"kernel_approximation",
"model_selection",
"multiclass",
"multioutput",
"pipeline",
"semi_supervised",
}


column_name_estimators = [
est
for est in _tested_estimators()
if est.__module__.split(".")[1] not in COLUMN_NAME_MODULES_TO_IGNORE
]


@pytest.mark.parametrize(
"estimator", column_name_estimators, ids=_get_check_estimator_ids
)
def test_pandas_column_name_consistency(estimator):
_set_checking_parameters(estimator)
with ignore_warnings(category=(FutureWarning)):
check_dataframe_column_names_consistency(
estimator.__class__.__name__, estimator
)
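
As a usage note (illustrative, not part of the diff), the new common check can also be run against a single estimator, assuming `check_dataframe_column_names_consistency` is importable from `sklearn.utils.estimator_checks` like the other checks used in this file:

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_dataframe_column_names_consistency

est = LogisticRegression()
check_dataframe_column_names_consistency(type(est).__name__, est)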