Skip to content

Commit aae8700

Browse files
jorisvandenbosschejnothman
authored andcommitted
FIX force pipeline steps to be list not a tuple (#9604)
1 parent 89b02af commit aae8700

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

sklearn/pipeline.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .base import clone, TransformerMixin
1818
from .externals.joblib import Parallel, delayed, Memory
1919
from .externals import six
20-
from .utils import tosequence
2120
from .utils.metaestimators import if_delegate_has_method
2221
from .utils import Bunch
2322

@@ -112,7 +111,7 @@ class Pipeline(_BaseComposition):
112111

113112
def __init__(self, steps, memory=None):
114113
# shallow copy of steps
115-
self.steps = tosequence(steps)
114+
self.steps = list(steps)
116115
self._validate_steps()
117116
self.memory = memory
118117

@@ -624,7 +623,7 @@ class FeatureUnion(_BaseComposition, TransformerMixin):
624623
625624
"""
626625
def __init__(self, transformer_list, n_jobs=1, transformer_weights=None):
627-
self.transformer_list = tosequence(transformer_list)
626+
self.transformer_list = list(transformer_list)
628627
self.n_jobs = n_jobs
629628
self.transformer_weights = transformer_weights
630629
self._validate_transformers()

sklearn/tests/test_pipeline.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,18 @@ def test_pipeline_init():
208208
assert_equal(params, params2)
209209

210210

211+
def test_pipeline_init_tuple():
212+
# Pipeline accepts steps as tuple
213+
X = np.array([[1, 2]])
214+
pipe = Pipeline((('transf', Transf()), ('clf', FitParamT())))
215+
pipe.fit(X, y=None)
216+
pipe.score(X)
217+
218+
pipe.set_params(transf=None)
219+
pipe.fit(X, y=None)
220+
pipe.score(X)
221+
222+
211223
def test_pipeline_methods_anova():
212224
# Test the various methods of the pipeline (anova).
213225
iris = load_iris()
@@ -425,6 +437,10 @@ def test_feature_union():
425437
FeatureUnion,
426438
[("transform", Transf()), ("no_transform", NoTrans())])
427439

440+
# test that init accepts tuples
441+
fs = FeatureUnion((("svd", svd), ("select", select)))
442+
fs.fit(X, y)
443+
428444

429445
def test_make_union():
430446
pca = PCA(svd_solver='full')

sklearn/utils/metaestimators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _set_params(self, attr, **params):
5151

5252
def _replace_estimator(self, attr, name, new_val):
5353
# assumes `name` is a valid estimator name
54-
new_estimators = getattr(self, attr)[:]
54+
new_estimators = list(getattr(self, attr))
5555
for i, (estimator_name, _) in enumerate(new_estimators):
5656
if estimator_name == name:
5757
new_estimators[i] = (name, new_val)

0 commit comments

Comments
 (0)