Skip to content

Commit f309ffb

Browse files
shubhraneel, glemaitre, ogrisel, thomasjpfan
authored
ENH add support for 'passthrough' in FeatureUnion (#20860)
Co-authored-by: shubhraneel <shubhraneel@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 39c90db commit f309ffb

File tree

3 files changed

+81
-14
lines changed

3 files changed

+81
-14
lines changed

doc/whats_new/v1.1.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ Changelog
5353
:pr:`20880` by :user:`Guillaume Lemaitre <glemaitre>`
5454
and :user:`András Simon <simonandras>`.
5555

56+
:mod:`sklearn.pipeline`
57+
.......................
58+
59+
- |Enhancement| Added support for "passthrough" in :class:`FeatureUnion`.
60+
Setting a transformer to "passthrough" will pass the features unchanged.
61+
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.
5662

5763
Code and Documentation Contributors
5864
-----------------------------------

sklearn/pipeline.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from joblib import Parallel
1818

1919
from .base import clone, TransformerMixin
20+
from .preprocessing import FunctionTransformer
2021
from .utils._estimator_html_repr import _VisualBlock
2122
from .utils.metaestimators import available_if
2223
from .utils import (
@@ -914,21 +915,24 @@ class FeatureUnion(TransformerMixin, _BaseComposition):
914915
915916
Parameters of the transformers may be set using its name and the parameter
916917
name separated by a '__'. A transformer may be replaced entirely by
917-
setting the parameter with its name to another transformer,
918-
or removed by setting to 'drop'.
918+
setting the parameter with its name to another transformer, removed by
919+
setting to 'drop' or disabled by setting to 'passthrough' (features are
920+
passed without transformation).
919921
920922
Read more in the :ref:`User Guide <feature_union>`.
921923
922924
.. versionadded:: 0.13
923925
924926
Parameters
925927
----------
926-
transformer_list : list of tuple
927-
List of tuple containing `(str, transformer)`. The first element
928-
of the tuple is name affected to the transformer while the
929-
second element is a scikit-learn transformer instance.
930-
The transformer instance can also be `"drop"` for it to be
931-
ignored.
928+
transformer_list : list of (str, transformer) tuples
929+
List of transformer objects to be applied to the data. The first
930+
half of each tuple is the name of the transformer. The transformer can
931+
be 'drop' for it to be ignored or can be 'passthrough' for features to
932+
be passed unchanged.
933+
934+
.. versionadded:: 1.1
935+
Added the option `"passthrough"`.
932936
933937
.. versionchanged:: 0.22
934938
Deprecated `None` as a transformer in favor of 'drop'.
@@ -1038,7 +1042,7 @@ def _validate_transformers(self):
10381042

10391043
# validate estimators
10401044
for t in transformers:
1041-
if t == "drop":
1045+
if t in ("drop", "passthrough"):
10421046
continue
10431047
if not (hasattr(t, "fit") or hasattr(t, "fit_transform")) or not hasattr(
10441048
t, "transform"
@@ -1065,12 +1069,15 @@ def _iter(self):
10651069
Generate (name, trans, weight) tuples excluding None and
10661070
'drop' transformers.
10671071
"""
1072+
10681073
get_weight = (self.transformer_weights or {}).get
1069-
return (
1070-
(name, trans, get_weight(name))
1071-
for name, trans in self.transformer_list
1072-
if trans != "drop"
1073-
)
1074+
1075+
for name, trans in self.transformer_list:
1076+
if trans == "drop":
1077+
continue
1078+
if trans == "passthrough":
1079+
trans = FunctionTransformer()
1080+
yield (name, trans, get_weight(name))
10741081

10751082
@deprecated(
10761083
"get_feature_names is deprecated in 1.0 and will be removed "

sklearn/tests/test_pipeline.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,60 @@ def test_set_feature_union_step_drop(get_names):
10041004
assert not record
10051005

10061006

1007+
def test_set_feature_union_passthrough():
1008+
"""Check the behaviour of setting a transformer to `"passthrough"`."""
1009+
mult2 = Mult(2)
1010+
mult3 = Mult(3)
1011+
X = np.asarray([[1]])
1012+
1013+
ft = FeatureUnion([("m2", mult2), ("m3", mult3)])
1014+
assert_array_equal([[2, 3]], ft.fit(X).transform(X))
1015+
assert_array_equal([[2, 3]], ft.fit_transform(X))
1016+
1017+
ft.set_params(m2="passthrough")
1018+
assert_array_equal([[1, 3]], ft.fit(X).transform(X))
1019+
assert_array_equal([[1, 3]], ft.fit_transform(X))
1020+
1021+
ft.set_params(m3="passthrough")
1022+
assert_array_equal([[1, 1]], ft.fit(X).transform(X))
1023+
assert_array_equal([[1, 1]], ft.fit_transform(X))
1024+
1025+
# check we can change back
1026+
ft.set_params(m3=mult3)
1027+
assert_array_equal([[1, 3]], ft.fit(X).transform(X))
1028+
assert_array_equal([[1, 3]], ft.fit_transform(X))
1029+
1030+
# Check 'passthrough' step at construction time
1031+
ft = FeatureUnion([("m2", "passthrough"), ("m3", mult3)])
1032+
assert_array_equal([[1, 3]], ft.fit(X).transform(X))
1033+
assert_array_equal([[1, 3]], ft.fit_transform(X))
1034+
1035+
X = iris.data
1036+
columns = X.shape[1]
1037+
pca = PCA(n_components=2, svd_solver="randomized", random_state=0)
1038+
1039+
ft = FeatureUnion([("passthrough", "passthrough"), ("pca", pca)])
1040+
assert_array_equal(X, ft.fit(X).transform(X)[:, :columns])
1041+
assert_array_equal(X, ft.fit_transform(X)[:, :columns])
1042+
1043+
ft.set_params(pca="passthrough")
1044+
X_ft = ft.fit(X).transform(X)
1045+
assert_array_equal(X_ft, np.hstack([X, X]))
1046+
X_ft = ft.fit_transform(X)
1047+
assert_array_equal(X_ft, np.hstack([X, X]))
1048+
1049+
ft.set_params(passthrough=pca)
1050+
assert_array_equal(X, ft.fit(X).transform(X)[:, -columns:])
1051+
assert_array_equal(X, ft.fit_transform(X)[:, -columns:])
1052+
1053+
ft = FeatureUnion(
1054+
[("passthrough", "passthrough"), ("pca", pca)],
1055+
transformer_weights={"passthrough": 2},
1056+
)
1057+
assert_array_equal(X * 2, ft.fit(X).transform(X)[:, :columns])
1058+
assert_array_equal(X * 2, ft.fit_transform(X)[:, :columns])
1059+
1060+
10071061
def test_step_name_validation():
10081062
error_message_1 = r"Estimator names must not contain __: got \['a__q'\]"
10091063
error_message_2 = r"Names provided are not unique: \['a', 'a'\]"

0 commit comments

Comments
 (0)