diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 61d7b12b7564d..d72c937f5e7be 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -7,10 +7,13 @@ # Virgile Fritsch # Alexandre Gramfort # Lars Buitinck +# Joel Nothman +# Guillaume Lemaitre # License: BSD from collections import defaultdict from abc import ABCMeta, abstractmethod +import copy import numpy as np from scipy import sparse @@ -233,6 +236,43 @@ def named_steps(self): def _final_estimator(self): return self.steps[-1][1] + def pop(self, pos=-1): + """Return the pipeline without a given step, and that step's estimator + + This is most often used to extract a pipeline consisting of the fitted + transformers preceding the last step, as well as the last step, to aid + model inspection. + + Parameters + ---------- + pos : int or str, optional + The index (0-based) or name of the step to be popped. Defaults to + the final step. + + Returns + ------- + sub_pipeline : Pipeline instance + This pipeline is a copy of ``self``, with the step at ``pos`` + removed. The constituent estimators are not copied: if the + Pipeline had been fit, so will be the returned Pipeline. + + The return type will be of the same type as self, if a subclass + is used and if its constructor is compatible. + + estimator : estimator instance + The estimator found at ``pos`` within this Pipeline's steps. + This is not copied and remains fitted if it was previously. + """ + if isinstance(pos, six.string_types): # retrieve index for name ... + pos = [name for name, _ in self.steps].index(pos) + elif pos < 0: + pos = len(self.steps) + pos + + # shallow copy of pipeline instance + out = copy.copy(self) + out.steps = self.steps[:pos] + self.steps[pos + 1:] + return out, self.steps[pos][1] + # Estimator interface def _fit(self, X, y=None, **fit_params): diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 33e3128931aff..fa72c1299147c 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -894,3 +894,64 @@ def test_pipeline_memory(): assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_) finally: shutil.rmtree(cachedir) + + +class MyPipeline(Pipeline): + pass + + +class MyPipelineNoMemory(Pipeline): + def __init__(self, steps, other_param): + super(MyPipelineNoMemory, self).__init__(steps) + self.other_param = other_param + + +def test_pipeline_pop(): + pipe = Pipeline([('transf1', Transf()), + ('transf2', Transf()), + ('predict', Mult())]) + pipe.fit(np.arange(5)[:, None], np.arange(5)) + + for pos, idx in [ + (0, 0), + (1, 1), + (2, 2), + (-3, 0), + (-2, 1), + (-1, 2), + ('transf1', 0), + ('transf2', 1), + ('predict', 2), + ]: + print(pos, idx) + new_pipe, popped_est = pipe.pop(pos) + assert_equal(len(pipe.steps) - 1, len(new_pipe.steps)) + expected_steps = pipe.steps[:idx] + pipe.steps[idx + 1:] + assert_equal(new_pipe.steps, expected_steps) + assert_dict_equal(new_pipe.named_steps, dict(expected_steps)) + for name in new_pipe.named_steps: + assert_true(new_pipe.named_steps[name] is pipe.named_steps[name]) + + assert_true(popped_est is pipe.steps[idx][1]) + + # invalid step name + assert_raise_message(ValueError, "'foo' is not in list", + pipe.pop, 'foo') + + # test subtype is maintained by pop + for memory in [None, '/path/to/somewhere']: + pipe = MyPipeline([('transf1', Transf()), + ('predict', Mult())], + memory=memory) + new_pipe, _ = pipe.pop() + assert_equal(type(new_pipe), type(pipe)) + assert_equal(new_pipe.steps, pipe.steps[:-1]) + assert_equal(pipe.memory, new_pipe.memory) + + # test subtype with different constructor signature + pipe = MyPipelineNoMemory([('transf1', Transf()), + ('predict', Mult())], + other_param='blah') + new_pipe, _ = pipe.pop() + assert_equal(new_pipe.steps, pipe.steps[:-1]) + assert_equal(pipe.other_param, new_pipe.other_param)