Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0476ca8
modified: sklearn/pipeline.py
it176131 Dec 21, 2022
58c907a
modified: doc/whats_new/v1.3.rst
it176131 Dec 22, 2022
ec6df53
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Dec 22, 2022
cebfa4a
modified: doc/whats_new/v1.3.rst
it176131 Dec 22, 2022
e60a2d9
modified: doc/whats_new/v1.3.rst
it176131 Dec 22, 2022
617f267
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Dec 22, 2022
4fcfac9
modified: sklearn/pipeline.py
it176131 Dec 22, 2022
62c5267
modified: doc/whats_new/v1.3.rst
it176131 Dec 22, 2022
d5a1736
Merge branch 'main' into feature_union_feature_names_in_
it176131 Dec 28, 2022
d73327a
modified: sklearn/tests/test_pipeline.py
it176131 Dec 28, 2022
e955171
Merge branch 'feature_union_feature_names_in_' of https://github.com/…
it176131 Dec 28, 2022
bfbf4ec
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Dec 28, 2022
71a08f2
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Dec 29, 2022
0fb8c1c
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Dec 30, 2022
b3dfe3b
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Jan 2, 2023
1ec5000
Update doc/whats_new/v1.3.rst
it176131 Jan 3, 2023
af78798
Update sklearn/tests/test_pipeline.py
it176131 Jan 3, 2023
073daf7
modified: sklearn/tests/test_pipeline.py
it176131 Jan 3, 2023
b5c168e
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Jan 3, 2023
b4c46c1
Merge remote-tracking branch 'origin/main' into feature_union_feature…
it176131 Jan 3, 2023
e709372
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 3, 2023
b84e70d
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 3, 2023
6fe1508
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 3, 2023
c543c01
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 4, 2023
d2046be
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 4, 2023
6678442
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 9, 2023
69f99c9
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 10, 2023
a02306c
Merge branch 'scikit-learn:main' into feature_union_feature_names_in_
it176131 Jan 10, 2023
9402393
Merge branch 'scikit-learn:main' into feature_union_feature_names_in_
it176131 Jan 11, 2023
05d0ae9
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 12, 2023
905529f
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 12, 2023
4027147
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 13, 2023
835d47b
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 13, 2023
f26df87
Merge branch 'main' into feature_union_feature_names_in_
it176131 Jan 13, 2023
b0a102f
Merge branch 'main' into feature_union_feature_names_in_
glemaitre Jan 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ Changelog
`feature_union["scalar"]`) to access transformers by name. :pr:`25093` by
`Thomas Fan`_.

- |Feature| :class:`pipeline.FeatureUnion` now exposes the
`feature_names_in_` attribute when the `X` seen during `fit` has a
`columns` attribute and all column names are strings, e.g. when `X` is a
`pandas.DataFrame`.
:pr:`25220` by :user:`Ian Thompson <it176131>`.

:mod:`sklearn.preprocessing`
............................

Expand Down
12 changes: 12 additions & 0 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,12 @@ class FeatureUnion(TransformerMixin, _BaseComposition):

.. versionadded:: 0.24

feature_names_in_ : ndarray of shape (`n_features_in_`,)
Names of features seen during :term:`fit`. Defined only when
`X` has feature names that are all strings.

.. versionadded:: 1.3

See Also
--------
make_union : Convenience function for simplified feature union
Expand Down Expand Up @@ -1297,6 +1303,12 @@ def n_features_in_(self):
# X is passed to all transformers so we just delegate to the first one
return self.transformer_list[0][1].n_features_in_

@property
def feature_names_in_(self):
    """Names of features seen during :term:`fit`.

    The value is read from the first entry of `transformer_list`; an
    `AttributeError` propagates when that transformer does not expose
    `feature_names_in_`.
    """
    # Every transformer receives the same X, so the first one is
    # representative of the feature names seen by the whole union.
    first_step = self.transformer_list[0]
    return first_step[1].feature_names_in_

def __sklearn_is_fitted__(self):
# Delegate whether feature union was fitted
for _, transformer, _ in self._iter():
Expand Down
32 changes: 32 additions & 0 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,3 +1647,35 @@ def test_feature_union_getitem_error(key):
msg = "Only string keys are supported"
with pytest.raises(KeyError, match=msg):
union[key]


def test_feature_union_feature_names_in_():
    """Check when `FeatureUnion` exposes the `feature_names_in_` attribute.

    Three scenarios are exercised: a union wrapping an already fitted
    transformer, a union fitted on a dataframe, and a union fitted on a
    NumPy array (where the attribute must be absent).

    Test for #24754.
    """
    pytest.importorskip("pandas")

    iris_frame = load_iris(as_frame=True, return_X_y=True)[0]

    # Scenario 1: the first transformer was fitted beforehand, so the union
    # should report the feature names that transformer recorded.
    fitted_scaler = StandardScaler().fit(iris_frame)
    prefit_union = FeatureUnion([("scale", fitted_scaler)])
    assert hasattr(prefit_union, "feature_names_in_")
    assert_array_equal(iris_frame.columns, prefit_union.feature_names_in_)
    assert_array_equal(
        fitted_scaler.feature_names_in_, prefit_union.feature_names_in_
    )

    # Scenario 2: fitting the union itself on a pandas.DataFrame.
    frame_union = FeatureUnion([("pass", "passthrough")])
    frame_union.fit(iris_frame)
    assert hasattr(frame_union, "feature_names_in_")
    assert_array_equal(iris_frame.columns, frame_union.feature_names_in_)

    # Scenario 3: fitting on a numpy array provides no column names, so the
    # attribute must not be available.
    iris_array = iris_frame.to_numpy()
    array_union = FeatureUnion([("pass", "passthrough")])
    array_union.fit(iris_array)
    assert not hasattr(array_union, "feature_names_in_")