Skip to content

Pipeline pop #8448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
# Virgile Fritsch
# Alexandre Gramfort
# Lars Buitinck
# Joel Nothman
# Guillaume Lemaitre
# License: BSD

from collections import defaultdict
from abc import ABCMeta, abstractmethod
import copy

import numpy as np
from scipy import sparse
Expand Down Expand Up @@ -233,6 +236,43 @@ def named_steps(self):
def _final_estimator(self):
return self.steps[-1][1]

def pop(self, pos=-1):
    """Return the pipeline without a given step, and that step's estimator

    This is most often used to extract a pipeline consisting of the fitted
    transformers preceding the last step, as well as the last step, to aid
    model inspection.

    Unlike ``list.pop``, ``self`` is left untouched: the step is removed
    from a shallow copy, which is returned alongside the estimator.

    Parameters
    ----------
    pos : int or str, optional
        The index (0-based) or name of the step to be popped. Negative
        indices count from the end, as for lists. Defaults to the final
        step.

    Returns
    -------
    sub_pipeline : Pipeline instance
        This pipeline is a copy of ``self``, with the step at ``pos``
        removed. The constituent estimators are not copied: if the
        Pipeline had been fit, so will be the returned Pipeline.

        The return type will be of the same type as self, since the copy
        is made with ``copy.copy`` rather than by calling a constructor.

    estimator : estimator instance
        The estimator found at ``pos`` within this Pipeline's steps.
        This is not copied and remains fitted if it was previously.

    Raises
    ------
    ValueError
        If ``pos`` is a string that does not name any step.
    IndexError
        If ``pos`` is an integer outside ``[-len(steps), len(steps))``.
    """
    n_steps = len(self.steps)
    if isinstance(pos, six.string_types):  # retrieve index for name ...
        pos = [name for name, _ in self.steps].index(pos)
    else:
        # Validate before normalising: otherwise an index such as
        # ``-(n_steps + 1)`` stays negative after adding ``n_steps`` and
        # silently wraps around a second time in the slicing below.
        if not -n_steps <= pos < n_steps:
            raise IndexError("Pipeline step index %r is out of range for a "
                             "pipeline with %d steps" % (pos, n_steps))
        if pos < 0:
            pos += n_steps

    # shallow copy of pipeline instance; only ``steps`` is rebound so the
    # original pipeline keeps its full list of steps
    out = copy.copy(self)
    out.steps = self.steps[:pos] + self.steps[pos + 1:]
    return out, self.steps[pos][1]

# Estimator interface

def _fit(self, X, y=None, **fit_params):
Expand Down
61 changes: 61 additions & 0 deletions sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,3 +894,64 @@ def test_pipeline_memory():
assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_)
finally:
shutil.rmtree(cachedir)


class MyPipeline(Pipeline):
    """Trivial Pipeline subclass that keeps the parent constructor."""


class MyPipelineNoMemory(Pipeline):
    """Pipeline subclass whose constructor takes ``other_param``.

    The constructor does not accept ``memory``; only ``steps`` is forwarded
    to the parent initialiser.
    """

    def __init__(self, steps, other_param):
        Pipeline.__init__(self, steps)
        self.other_param = other_param


def test_pipeline_pop():
    """Check ``Pipeline.pop`` for index, negative index and name lookups.

    Also checks that an unknown step name raises ``ValueError`` and that
    the returned pipeline keeps the subclass type and extra attributes.
    """
    pipe = Pipeline([('transf1', Transf()),
                     ('transf2', Transf()),
                     ('predict', Mult())])
    pipe.fit(np.arange(5)[:, None], np.arange(5))

    # (argument given to pop, equivalent non-negative index)
    cases = [
        (0, 0),
        (1, 1),
        (2, 2),
        (-3, 0),
        (-2, 1),
        (-1, 2),
        ('transf1', 0),
        ('transf2', 1),
        ('predict', 2),
    ]
    for pos, idx in cases:
        new_pipe, popped_est = pipe.pop(pos)
        assert_equal(len(pipe.steps) - 1, len(new_pipe.steps))
        expected_steps = pipe.steps[:idx] + pipe.steps[idx + 1:]
        assert_equal(new_pipe.steps, expected_steps)
        assert_dict_equal(new_pipe.named_steps, dict(expected_steps))
        # the remaining estimators must be the very same objects
        for name in new_pipe.named_steps:
            assert_true(new_pipe.named_steps[name] is pipe.named_steps[name])

        assert_true(popped_est is pipe.steps[idx][1])

    # invalid step name
    assert_raise_message(ValueError, "'foo' is not in list",
                         pipe.pop, 'foo')

    # test subtype is maintained by pop
    for memory in [None, '/path/to/somewhere']:
        pipe = MyPipeline([('transf1', Transf()),
                           ('predict', Mult())],
                          memory=memory)
        new_pipe, _ = pipe.pop()
        assert_equal(type(new_pipe), type(pipe))
        assert_equal(new_pipe.steps, pipe.steps[:-1])
        assert_equal(pipe.memory, new_pipe.memory)

    # test subtype with different constructor signature
    pipe = MyPipelineNoMemory([('transf1', Transf()),
                               ('predict', Mult())],
                              other_param='blah')
    new_pipe, _ = pipe.pop()
    assert_equal(new_pipe.steps, pipe.steps[:-1])
    assert_equal(pipe.other_param, new_pipe.other_param)