-
-
Notifications
You must be signed in to change notification settings - Fork 26.2k
ENH Adds Column name consistency #18010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
779385f
a579392
74368fd
bce5d0f
2ca4dbf
f6048d7
485b5ca
df6d193
7465ec2
4d5c3d4
19583bc
cb3e6be
4d0840a
53270fe
4f7c5e2
37117eb
5ed789b
ee09732
7564e75
df146d1
1b81d12
3522f37
a81e2a3
2c45b65
f43356b
8efb395
46f332d
c93bd9d
5039f5a
ee03ab7
be480b6
8c5f425
3dd4041
670996a
a4af9c0
afae3fc
f57dcb0
f853336
aeb220e
a2ce8b2
f344353
f114e98
396d3ea
8372a6e
dd36120
bacbec1
d86e70e
8a4212f
ee95642
7aeae36
4c79f62
2e4e422
d112c61
59a1cc1
9494a1f
f70a56e
62bb28b
a0d4d12
7bc3d8b
6e42e0f
a171666
681d045
673811a
1a2bf25
ef3657c
7b1999e
966f50f
decb967
5e89f6c
77221fd
19de717
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
from .utils.validation import _check_y | ||
from .utils.validation import _num_features | ||
from .utils._estimator_html_repr import estimator_html_repr | ||
from .utils.validation import _get_feature_names | ||
|
||
|
||
def clone(estimator, *, safe=True): | ||
|
@@ -395,6 +396,92 @@ def _check_n_features(self, X, reset): | |
f"is expecting {self.n_features_in_} features as input." | ||
) | ||
|
||
def _check_feature_names(self, X, *, reset):
    """Set or check the `feature_names_in_` attribute.

    .. versionadded:: 1.0

    Parameters
    ----------
    X : {ndarray, dataframe} of shape (n_samples, n_features)
        The input samples.

    reset : bool
        Whether to reset the `feature_names_in_` attribute.
        If False, the input will be checked for consistency with
        feature names of data provided when reset was last True.

        .. note::
           It is recommended to call `reset=True` in `fit` and in the first
           call to `partial_fit`. All other methods that validate `X`
           should set `reset=False`.

    Warns
    -----
    UserWarning
        If exactly one of `X` and the fitted estimator has feature names.
    FutureWarning
        If both have feature names but they differ (in content or order).
    """
    if reset:
        feature_names_in = _get_feature_names(X)
        if feature_names_in is not None:
            self.feature_names_in_ = feature_names_in
        elif hasattr(self, "feature_names_in_"):
            # Re-fitting on data without feature names: drop the names
            # recorded by a previous fit so that later checks do not compare
            # against stale state.
            delattr(self, "feature_names_in_")
        return

    fitted_feature_names = getattr(self, "feature_names_in_", None)
    X_feature_names = _get_feature_names(X)

    if fitted_feature_names is None and X_feature_names is None:
        # no feature names seen in fit and in X
        return

    if X_feature_names is not None and fitted_feature_names is None:
        warnings.warn(
            f"X has feature names, but {self.__class__.__name__} was fitted without"
            " feature names"
        )
        return

    if X_feature_names is None and fitted_feature_names is not None:
        warnings.warn(
            "X does not have valid feature names, but"
            f" {self.__class__.__name__} was fitted with feature names"
        )
        return

    # validate the feature names against the `feature_names_in_` attribute
    if len(fitted_feature_names) != len(X_feature_names) or np.any(
        fitted_feature_names != X_feature_names
    ):
        message = (
            "The feature names should match those that were "
            "passed during fit. Starting version 1.2, an error will be raised.\n"
        )
        fitted_feature_names_set = set(fitted_feature_names)
        X_feature_names_set = set(X_feature_names)

        unexpected_names = sorted(X_feature_names_set - fitted_feature_names_set)
        missing_names = sorted(fitted_feature_names_set - X_feature_names_set)

        def add_names(names):
            # Render at most `max_n_names` entries as a bullet list, followed
            # by an ellipsis bullet when the list was truncated.
            max_n_names = 5
            lines = [f"- {name}\n" for name in names[:max_n_names]]
            if len(names) > max_n_names:
                lines.append("- ...\n")
            return "".join(lines)

        if unexpected_names:
            message += "Feature names unseen at fit time:\n"
            message += add_names(unexpected_names)

        if missing_names:
            message += "Feature names seen at fit time, yet now missing:\n"
            message += add_names(missing_names)

        # Only when the two sets of names are identical can the mismatch be
        # an ordering problem, so BOTH differences must be empty (the
        # original condition tested `missing_names` twice by mistake).
        if not missing_names and not unexpected_names:
            message += (
                "Feature names must be in the same order as they were in fit.\n"
            )

        warnings.warn(message, FutureWarning)
|
||
def _validate_data( | ||
self, | ||
X="no_validation", | ||
|
@@ -452,6 +539,8 @@ def _validate_data( | |
The validated input. A tuple is returned if both `X` and `y` are | ||
validated. | ||
""" | ||
self._check_feature_names(X, reset=reset) | ||
|
||
if y is None and self._get_tags()["requires_y"]: | ||
raise ValueError( | ||
f"This {self.__class__.__name__} estimator " | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -539,6 +539,7 @@ def predict(self, X): | |
Returns predicted values. | ||
""" | ||
check_is_fitted(self) | ||
self._check_feature_names(X, reset=False) | ||
|
||
return self.estimator_.predict(X) | ||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
@@ -561,6 +562,7 @@ def score(self, X, y): | |
Score of the prediction. | ||
""" | ||
check_is_fitted(self) | ||
self._check_feature_names(X, reset=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment here. |
||
|
||
return self.estimator_.score(X, y) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
# Author: Gael Varoquaux | ||
# License: BSD 3 clause | ||
|
||
import re | ||
import numpy as np | ||
import scipy.sparse as sp | ||
import pytest | ||
|
@@ -615,3 +616,73 @@ def test_n_features_in_no_validation(): | |
|
||
# does not raise | ||
est._check_n_features("invalid X", reset=False) | ||
|
||
|
||
def test_feature_names_in():
    """Check that `feature_names_in_` is recorded and checked by `_validate_data`."""
    # Local import keeps this test self-contained regardless of the module's
    # top-level imports.
    import warnings

    pd = pytest.importorskip("pandas")
    iris = datasets.load_iris()
    X_np = iris.data
    df = pd.DataFrame(X_np, columns=iris.feature_names)

    class NoOpTransformer(TransformerMixin, BaseEstimator):
        def fit(self, X, y=None):
            self._validate_data(X)
            return self

        def transform(self, X):
            self._validate_data(X, reset=False)
            return X

    # fit on dataframe saves the feature names
    trans = NoOpTransformer().fit(df)
    assert_array_equal(trans.feature_names_in_, df.columns)

    # same names in a different order -> FutureWarning
    msg = "The feature names should match those that were passed"
    df_bad = pd.DataFrame(X_np, columns=iris.feature_names[::-1])
    with pytest.warns(FutureWarning, match=msg):
        trans.transform(df_bad)

    # warns when fitted on dataframe and transforming a ndarray
    msg = (
        "X does not have valid feature names, but NoOpTransformer was "
        "fitted with feature names"
    )
    with pytest.warns(UserWarning, match=msg):
        trans.transform(X_np)

    # warns when fitted on a ndarray and transforming dataframe
    msg = "X has feature names, but NoOpTransformer was fitted without feature names"
    trans = NoOpTransformer().fit(X_np)
    with pytest.warns(UserWarning, match=msg):
        trans.transform(df)

    # fit on dataframe with all integer feature names works without warning.
    # `pytest.warns(None)` is deprecated, so turn every warning into an error
    # instead: the block fails if anything is emitted.
    df_int_names = pd.DataFrame(X_np)
    trans = NoOpTransformer()
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        trans.fit(df_int_names)

    # fit on dataframe with no feature names or all integer feature names
    # -> do not warn on transform
    Xs = [X_np, df_int_names]
    for X in Xs:
        with warnings.catch_warnings():
            warnings.simplefilter("error")
            trans.transform(X)

    # TODO: Convert to an error in 1.2
    # fit on dataframe with feature names that are mixed warns:
    df_mixed = pd.DataFrame(X_np, columns=["a", "b", 1, 2])
    trans = NoOpTransformer()
    msg = re.escape(
        "Feature names only support names that are all strings. "
        "Got feature names with dtypes: ['int', 'str']"
    )
    with pytest.warns(FutureWarning, match=msg):
        trans.fit(df_mixed)

    # transform on feature names that are mixed also warns:
    with pytest.warns(FutureWarning, match=msg):
        trans.transform(df_mixed)
Uh oh!
There was an error while loading. Please reload this page.