-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+1] Pipeline can now be sliced or indexed #2568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
53ff58b
d02a64a
7fa737d
d305273
ef266c1
48ee35e
a5024ca
7b21322
86dc37f
0840b84
f7d20ff
1b92159
d6e4146
6582b06
96509b1
210b26f
3733569
86bd075
3d06b24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,17 +99,28 @@ class Pipeline(_BaseComposition): | |
>>> anova_svm.score(X, y) # doctest: +ELLIPSIS | ||
0.83 | ||
>>> # getting the selected features chosen by anova_filter | ||
>>> anova_svm.named_steps['anova'].get_support() | ||
>>> anova_svm['anova'].get_support() | ||
... # doctest: +NORMALIZE_WHITESPACE | ||
array([False, False, True, True, False, False, True, True, False, | ||
True, False, True, True, False, True, False, True, True, | ||
array([False, False, True, True, False, False, True, True, False, | ||
True, False, True, True, False, True, False, True, True, | ||
False, False]) | ||
>>> # Another way to get selected features chosen by anova_filter | ||
>>> anova_svm.named_steps.anova.get_support() | ||
... # doctest: +NORMALIZE_WHITESPACE | ||
array([False, False, True, True, False, False, True, True, False, | ||
True, False, True, True, False, True, False, True, True, | ||
array([False, False, True, True, False, False, True, True, False, | ||
True, False, True, True, False, True, False, True, True, | ||
False, False]) | ||
>>> # Indexing can also be used to extract a sub-pipeline. | ||
>>> sub_pipeline = anova_svm[:1] | ||
>>> sub_pipeline # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE | ||
Pipeline(memory=None, steps=[('anova', ...)]) | ||
>>> coef = anova_svm[-1].coef_ | ||
>>> anova_svm['svc'] is anova_svm[-1] | ||
True | ||
>>> coef.shape | ||
(1, 10) | ||
>>> sub_pipeline.inverse_transform(coef).shape | ||
(1, 20) | ||
""" | ||
|
||
# BaseEstimator interface | ||
|
@@ -188,6 +199,26 @@ def _iter(self, with_final=True): | |
if trans is not None and trans != 'passthrough': | ||
yield idx, name, trans | ||
|
||
def __getitem__(self, ind): | ||
"""Returns a sub-pipeline or a single estimator in the pipeline | ||
|
||
Indexing with an integer will return an estimator; using a slice | ||
returns another Pipeline instance which copies a slice of this | ||
Pipeline. This copy is shallow: modifying (or fitting) estimators in | ||
the sub-pipeline will affect the larger pipeline and vice-versa. | ||
However, replacing a value in `steps` of either the original pipeline or the sub-pipeline will not affect the other. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Perhaps clarify: replacing a value in `steps` in the original pipeline instance or the sub-pipeline instance. |
||
""" | ||
if isinstance(ind, slice): | ||
if ind.step not in (1, None): | ||
raise ValueError('Pipeline slicing only supports a step of 1') | ||
return self.__class__(self.steps[ind]) | ||
try: | ||
name, est = self.steps[ind] | ||
except TypeError: | ||
# Not an int, try get step by name | ||
return self.named_steps[ind] | ||
return est | ||
|
||
@property | ||
def _estimator_type(self): | ||
return self.steps[-1][1]._estimator_type | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -529,6 +529,29 @@ def test_pipeline_fit_transform(): | |
assert_array_almost_equal(X_trans, X_trans2) | ||
|
||
|
||
def test_pipeline_slice(): | ||
pipe = Pipeline([('transf1', Transf()), | ||
('transf2', Transf()), | ||
('clf', FitParamT())]) | ||
pipe2 = pipe[:-1] | ||
assert isinstance(pipe2, Pipeline) | ||
assert pipe2.steps == pipe.steps[:-1] | ||
assert 2 == len(pipe2.named_steps) | ||
assert_raises(ValueError, lambda: pipe[::-1]) | ||
|
||
|
||
def test_pipeline_index(): | ||
transf = Transf() | ||
clf = FitParamT() | ||
pipe = Pipeline([('transf', transf), ('clf', clf)]) | ||
assert pipe[0] == transf | ||
assert pipe['transf'] == transf | ||
assert pipe[-1] == clf | ||
assert pipe['clf'] == clf | ||
assert_raises(IndexError, lambda: pipe[3]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Perhaps another test could be added for a sub-pipeline slice over several steps that exceeds the maximum index. The present test covers the case where a single estimator is returned, but not the case where a sub-pipeline is returned as a Pipeline() instance. |
||
assert_raises(KeyError, lambda: pipe['foobar']) | ||
|
||
|
||
def test_set_pipeline_steps(): | ||
transf1 = Transf() | ||
transf2 = Transf() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo in "estimator"