diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst index 5ce5386343666..a48164b09470e 100644 --- a/doc/modules/pipeline.rst +++ b/doc/modules/pipeline.rst @@ -39,13 +39,10 @@ is an estimator object:: >>> from sklearn.decomposition import PCA >>> estimators = [('reduce_dim', PCA()), ('clf', SVC())] >>> pipe = Pipeline(estimators) - >>> pipe # doctest: +NORMALIZE_WHITESPACE - Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power='auto', - n_components=None, random_state=None, svd_solver='auto', tol=0.0, - whiten=False)), ('clf', SVC(C=1.0, cache_size=200, class_weight=None, - coef0=0.0, decision_function_shape='ovr', degree=3, gamma='auto', - kernel='rbf', max_iter=-1, probability=False, random_state=None, - shrinking=True, tol=0.001, verbose=False))]) + >>> pipe # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS + Pipeline(memory=None, + steps=[('reduce_dim', PCA(copy=True,...)), + ('clf', SVC(C=1.0,...))]) The utility function :func:`make_pipeline` is a shorthand for constructing pipelines; @@ -56,7 +53,8 @@ filling in the names automatically:: >>> from sklearn.naive_bayes import MultinomialNB >>> from sklearn.preprocessing import Binarizer >>> make_pipeline(Binarizer(), MultinomialNB()) # doctest: +NORMALIZE_WHITESPACE - Pipeline(steps=[('binarizer', Binarizer(copy=True, threshold=0.0)), + Pipeline(memory=None, + steps=[('binarizer', Binarizer(copy=True, threshold=0.0)), ('multinomialnb', MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True))]) @@ -76,30 +74,26 @@ and as a ``dict`` in ``named_steps``:: Parameters of the estimators in the pipeline can be accessed using the ``__`` syntax:: - >>> pipe.set_params(clf__C=10) # doctest: +NORMALIZE_WHITESPACE - Pipeline(steps=[('reduce_dim', PCA(copy=True, iterated_power='auto', - n_components=None, random_state=None, svd_solver='auto', tol=0.0, - whiten=False)), ('clf', SVC(C=10, cache_size=200, class_weight=None, - coef0=0.0, decision_function_shape='ovr', degree=3, gamma='auto', - kernel='rbf', max_iter=-1, probability=False, random_state=None, - shrinking=True, tol=0.001, verbose=False))]) - + >>> pipe.set_params(clf__C=10) # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS + Pipeline(memory=None, + steps=[('reduce_dim', PCA(copy=True, iterated_power='auto',...)), + ('clf', SVC(C=10, cache_size=200, class_weight=None,...))]) This is particularly important for doing grid searches:: >>> from sklearn.model_selection import GridSearchCV - >>> params = dict(reduce_dim__n_components=[2, 5, 10], - ... clf__C=[0.1, 10, 100]) - >>> grid_search = GridSearchCV(pipe, param_grid=params) + >>> param_grid = dict(reduce_dim__n_components=[2, 5, 10], + ... clf__C=[0.1, 10, 100]) + >>> grid_search = GridSearchCV(pipe, param_grid=param_grid) Individual steps may also be replaced as parameters, and non-final steps may be ignored by setting them to ``None``:: >>> from sklearn.linear_model import LogisticRegression - >>> params = dict(reduce_dim=[None, PCA(5), PCA(10)], - ... clf=[SVC(), LogisticRegression()], - ... clf__C=[0.1, 10, 100]) - >>> grid_search = GridSearchCV(pipe, param_grid=params) + >>> param_grid = dict(reduce_dim=[None, PCA(5), PCA(10)], + ... clf=[SVC(), LogisticRegression()], + ... clf__C=[0.1, 10, 100]) + >>> grid_search = GridSearchCV(pipe, param_grid=param_grid) .. 
topic:: Examples:
@@ -108,6 +102,7 @@ ignored by setting them to ``None``::
  * :ref:`sphx_glr_auto_examples_plot_digits_pipe.py`
  * :ref:`sphx_glr_auto_examples_plot_kernel_approximation.py`
  * :ref:`sphx_glr_auto_examples_svm_plot_svm_anova.py`
+ * :ref:`sphx_glr_auto_examples_plot_compare_reduction.py`

 .. topic:: See also:

@@ -124,6 +119,84 @@ i.e. if the last estimator is a classifier, the :class:`Pipeline` can be used
 as a classifier. If the last estimator is a transformer, again, so is the
 pipeline.

+Caching transformers: avoid repeated computation
+-------------------------------------------------
+
+.. currentmodule:: sklearn.pipeline
+
+Fitting transformers may be computationally expensive. With its
+``memory`` parameter set, :class:`Pipeline` will cache each transformer
+after calling ``fit``.
+This feature is used to avoid recomputing the fitted transformers within a
+pipeline when the parameters and input data are identical. A typical example
+is a grid search, in which the transformers can be fitted only once and
+reused for each configuration.
+
+The parameter ``memory`` is needed in order to cache the transformers.
+``memory`` can be either a string giving the directory in which to cache the
+transformers or a `joblib.Memory `_
+object::
+
+    >>> from tempfile import mkdtemp
+    >>> from shutil import rmtree
+    >>> from sklearn.decomposition import PCA
+    >>> from sklearn.svm import SVC
+    >>> from sklearn.pipeline import Pipeline
+    >>> estimators = [('reduce_dim', PCA()), ('clf', SVC())]
+    >>> cachedir = mkdtemp()
+    >>> pipe = Pipeline(estimators, memory=cachedir)
+    >>> pipe # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+    Pipeline(...,
+             steps=[('reduce_dim', PCA(copy=True,...)),
+                    ('clf', SVC(C=1.0,...))])
+    >>> # Clear the cache directory when you don't need it anymore
+    >>> rmtree(cachedir)

+.. warning:: **Side effect of caching transformers**
+
+   When a :class:`Pipeline` is used without caching enabled, the original
+   transformer instances can be inspected directly::
+
+     >>> from sklearn.datasets import load_digits
+     >>> digits = load_digits()
+     >>> pca1 = PCA()
+     >>> svm1 = SVC()
+     >>> pipe = Pipeline([('reduce_dim', pca1), ('clf', svm1)])
+     >>> pipe.fit(digits.data, digits.target)
+     ... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+     Pipeline(memory=None,
+              steps=[('reduce_dim', PCA(...)), ('clf', SVC(...))])
+     >>> # The pca instance can be inspected directly
+     >>> print(pca1.components_) # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+         [[ -1.77484909e-19 ... 4.07058917e-18]]
+
+   Enabling caching triggers a clone of the transformers before fitting.
+   Therefore, the transformer instance given to the pipeline cannot be
+   inspected directly.
+   In the following example, accessing the :class:`PCA` instance ``pca2``
+   will raise an ``AttributeError`` since ``pca2`` will be an unfitted
+   transformer.
+   Instead, use the attribute ``named_steps`` to inspect the estimators within
+   the pipeline::
+
+     >>> cachedir = mkdtemp()
+     >>> pca2 = PCA()
+     >>> svm2 = SVC()
+     >>> cached_pipe = Pipeline([('reduce_dim', pca2), ('clf', svm2)],
+     ...                        memory=cachedir)
+     >>> cached_pipe.fit(digits.data, digits.target)
+     ... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+     Pipeline(memory=...,
+              steps=[('reduce_dim', PCA(...)), ('clf', SVC(...))])
+     >>> print(cached_pipe.named_steps['reduce_dim'].components_)
+     ... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+         [[ -1.77484909e-19 ... 4.07058917e-18]]
+     >>> # Remove the cache directory
+     >>> rmtree(cachedir)
+
+.. topic:: Examples:
+
+ * :ref:`sphx_glr_auto_examples_plot_compare_reduction.py`

 .. _feature_union:

@@ -164,15 +237,11 @@ and ``value`` is an estimator object::
     >>> from sklearn.decomposition import KernelPCA
     >>> estimators = [('linear_pca', PCA()), ('kernel_pca', KernelPCA())]
     >>> combined = FeatureUnion(estimators)
-    >>> combined # doctest: +NORMALIZE_WHITESPACE
-    FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
-        iterated_power='auto', n_components=None, random_state=None,
-        svd_solver='auto', tol=0.0, whiten=False)), ('kernel_pca',
-        KernelPCA(alpha=1.0, coef0=1, copy_X=True, degree=3,
-        eigen_solver='auto', fit_inverse_transform=False, gamma=None,
-        kernel='linear', kernel_params=None, max_iter=None, n_components=None,
-        n_jobs=1, random_state=None, remove_zero_eig=False, tol=0))],
-        transformer_weights=None)
+    >>> combined # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+    FeatureUnion(n_jobs=1,
+                 transformer_list=[('linear_pca', PCA(copy=True,...)),
+                                   ('kernel_pca', KernelPCA(alpha=1.0,...))],
+                 transformer_weights=None)


 Like pipelines, feature unions have a shorthand constructor called
@@ -182,11 +251,12 @@ Like ``Pipeline``, individual steps may be replaced using ``set_params``,
 and ignored by setting to ``None``::

-    >>> combined.set_params(kernel_pca=None) # doctest: +NORMALIZE_WHITESPACE
-    FeatureUnion(n_jobs=1, transformer_list=[('linear_pca', PCA(copy=True,
-        iterated_power='auto', n_components=None, random_state=None,
-        svd_solver='auto', tol=0.0, whiten=False)), ('kernel_pca', None)],
-        transformer_weights=None)
+    >>> combined.set_params(kernel_pca=None)
+    ... # doctest: +NORMALIZE_WHITESPACE, +ELLIPSIS
+    FeatureUnion(n_jobs=1,
+                 transformer_list=[('linear_pca', PCA(copy=True,...)),
+                                   ('kernel_pca', None)],
+                 transformer_weights=None)

 .. topic:: Examples:

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 9d76f5377c0e1..4afe8a5dc52fb 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -56,6 +56,9 @@ Enhancements
    - :class:`multioutput.MultiOutputRegressor` and
      :class:`multioutput.MultiOutputClassifier` now support online learning
      using `partial_fit`. issue: `8053` by :user:`Peng Yu `.
+   - :class:`pipeline.Pipeline` can now cache transformers
+     within a pipeline by using the ``memory`` constructor parameter.
+     :issue:`7990` by :user:`Guillaume Lemaitre `.

    - :class:`decomposition.PCA`, :class:`decomposition.IncrementalPCA` and
      :class:`decomposition.TruncatedSVD` now expose the singular values
diff --git a/examples/plot_compare_reduction.py b/examples/plot_compare_reduction.py
old mode 100644
new mode 100755
index 1c84ea9c3a4dc..05ea0168a5906
--- a/examples/plot_compare_reduction.py
+++ b/examples/plot_compare_reduction.py
@@ -1,4 +1,4 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 # -*- coding: utf-8 -*-
 """
 =================================================================
@@ -7,13 +7,27 @@
 This example constructs a pipeline that does dimensionality
 reduction followed by prediction with a support vector
-classifier. It demonstrates the use of GridSearchCV and
-Pipeline to optimize over different classes of estimators in a
-single CV run -- unsupervised PCA and NMF dimensionality
+classifier. It demonstrates the use of ``GridSearchCV`` and
+``Pipeline`` to optimize over different classes of estimators in a
+single CV run -- unsupervised ``PCA`` and ``NMF`` dimensionality
 reductions are compared to univariate feature selection during
 the grid search.
+
+Additionally, ``Pipeline`` can be instantiated with the ``memory`` argument
+to memoize the transformers within the pipeline, avoiding fitting the same
+transformers over and over again.
+
+Note that using ``memory`` to enable caching is mainly worthwhile when
+fitting a transformer is costly.
 """
-# Authors: Robert McGibbon, Joel Nothman
+
+###############################################################################
+# Illustration of ``Pipeline`` and ``GridSearchCV``
+###############################################################################
+# This section illustrates the use of a ``Pipeline`` with ``GridSearchCV``.
+
+# Authors: Robert McGibbon, Joel Nothman, Guillaume Lemaitre

 from __future__ import print_function, division

@@ -49,7 +63,7 @@
 ]
 reducer_labels = ['PCA', 'NMF', 'KBest(chi2)']

-grid = GridSearchCV(pipe, cv=3, n_jobs=2, param_grid=param_grid)
+grid = GridSearchCV(pipe, cv=3, n_jobs=1, param_grid=param_grid)
 digits = load_digits()
 grid.fit(digits.data, digits.target)

@@ -72,4 +86,45 @@
 plt.ylabel('Digit classification accuracy')
 plt.ylim((0, 1))
 plt.legend(loc='upper left')
+
+###############################################################################
+# Caching transformers within a ``Pipeline``
+###############################################################################
+# It is sometimes worthwhile to store the state of a specific transformer
+# since it could be used again. Using a pipeline in ``GridSearchCV`` triggers
+# such situations. Therefore, we use the argument ``memory`` to enable caching.
+#
+# .. warning::
+#     Note that this example is, however, only an illustration since for this
+#     specific case fitting PCA is not necessarily slower than loading the
+#     cache. Hence, use the ``memory`` constructor parameter when the fitting
+#     of a transformer is costly.
+
+from tempfile import mkdtemp
+from shutil import rmtree
+from sklearn.externals.joblib import Memory
+
+# Create a temporary folder to store the transformers of the pipeline
+cachedir = mkdtemp()
+memory = Memory(cachedir=cachedir, verbose=10)
+cached_pipe = Pipeline([('reduce_dim', PCA()),
+                        ('classify', LinearSVC())],
+                       memory=memory)
+
+# This time, a cached pipeline will be used within the grid search
+grid = GridSearchCV(cached_pipe, cv=3, n_jobs=1, param_grid=param_grid)
+digits = load_digits()
+grid.fit(digits.data, digits.target)
+
+# Delete the temporary cache before exiting
+rmtree(cachedir)
+
+###############################################################################
+# The ``PCA`` fitting is only computed when evaluating the first
+# configuration of the ``C`` parameter of the ``LinearSVC`` classifier. The
+# other configurations of ``C`` will trigger the loading of the cached ``PCA``
+# estimator data, saving processing time. Therefore, caching the pipeline
+# with ``memory`` is highly beneficial when fitting a transformer is costly.
+
 plt.show()
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 784fad75b77ac..61d7b12b7564d 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -15,8 +15,8 @@
 import numpy as np
 from scipy import sparse

-from .base import BaseEstimator, TransformerMixin
-from .externals.joblib import Parallel, delayed
+from .base import clone, BaseEstimator, TransformerMixin
+from .externals.joblib import Parallel, delayed, Memory
 from .externals import six
 from .utils import tosequence
 from .utils.metaestimators import if_delegate_has_method
@@ -89,6 +89,7 @@ class Pipeline(_BasePipeline):
     Intermediate steps of the pipeline must be 'transforms', that is, they
     must implement fit and transform methods.
     The final estimator only needs to implement fit.
+    The transformers in the pipeline can be cached using the ``memory`` argument.

     The purpose of the pipeline is to assemble several steps that can be
     cross-validated together while setting different parameters.
@@ -107,6 +108,18 @@ class Pipeline(_BasePipeline):
         chained, in the order in which they are chained, with the last object
         an estimator.

+    memory : Instance of joblib.Memory or string, optional (default=None)
+        Used to cache the fitted transformers of the pipeline. By default,
+        no caching is performed. If a string is given, it is the path to
+        the caching directory.
+        Enabling caching triggers a clone of the transformers before fitting.
+        Therefore, the transformer instance given to the pipeline cannot be
+        inspected directly. Use the attribute ``named_steps`` or ``steps``
+        to inspect estimators within the pipeline.
+        Caching the transformers is advantageous when fitting is time
+        consuming.
+
+
     Attributes
     ----------
     named_steps : dict
@@ -131,8 +144,10 @@ class Pipeline(_BasePipeline):
     >>> # For instance, fit using a k of 10 in the SelectKBest
     >>> # and a parameter 'C' of the svm
     >>> anova_svm.set_params(anova__k=10, svc__C=.1).fit(X, y)
-    ... # doctest: +ELLIPSIS
-    Pipeline(steps=[...])
+    ... # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+    Pipeline(memory=None,
+             steps=[('anova', SelectKBest(...)),
+                    ('svc', SVC(...))])
     >>> prediction = anova_svm.predict(X)
     >>> anova_svm.score(X, y) # doctest: +ELLIPSIS
     0.829...
@@ -142,14 +157,16 @@
     array([False, False,  True,  True, False, False,  True,  True, False,
            True, False,  True,  True, False,  True, False,  True,  True,
            False, False], dtype=bool)
+
     """

     # BaseEstimator interface

-    def __init__(self, steps):
+    def __init__(self, steps, memory=None):
         # shallow copy of steps
         self.steps = tosequence(steps)
         self._validate_steps()
+        self.memory = memory

     def get_params(self, deep=True):
         """Get parameters for this estimator.
@@ -220,20 +237,43 @@ def _final_estimator(self):

     def _fit(self, X, y=None, **fit_params):
         self._validate_steps()
+        # Set up the memory
+        memory = self.memory
+        if memory is None:
+            memory = Memory(cachedir=None, verbose=0)
+        elif isinstance(memory, six.string_types):
+            memory = Memory(cachedir=memory, verbose=0)
+        elif not isinstance(memory, Memory):
+            raise ValueError("'memory' should either be a string or"
+                             " a joblib.Memory instance, got"
+                             " 'memory={!r}' instead.".format(memory))
+
+        fit_transform_one_cached = memory.cache(_fit_transform_one)
+
         fit_params_steps = dict((name, {}) for name, step in self.steps
                                 if step is not None)
         for pname, pval in six.iteritems(fit_params):
             step, param = pname.split('__', 1)
             fit_params_steps[step][param] = pval
         Xt = X
-        for name, transform in self.steps[:-1]:
-            if transform is None:
+        for step_idx, (name, transformer) in enumerate(self.steps[:-1]):
+            if transformer is None:
                 pass
-            elif hasattr(transform, "fit_transform"):
-                Xt = transform.fit_transform(Xt, y, **fit_params_steps[name])
             else:
-                Xt = transform.fit(Xt, y, **fit_params_steps[name]) \
-                              .transform(Xt)
+                if memory.cachedir is None:
+                    # we do not clone when caching is disabled to preserve
+                    # backward compatibility
+                    cloned_transformer = transformer
+                else:
+                    cloned_transformer = clone(transformer)
+                # Fit or load from cache the current transformer
+                Xt, fitted_transformer = fit_transform_one_cached(
+                    cloned_transformer, None, Xt, y,
+                    **fit_params_steps[name])
+                # Replace the transformer of the step with the fitted
+                # transformer. This is necessary when loading the transformer
+                # from the cache.
+                self.steps[step_idx] = (name, fitted_transformer)
         if self._final_estimator is None:
             return Xt, {}
         return Xt, fit_params_steps[self.steps[-1][0]]
@@ -550,7 +590,8 @@ def make_pipeline(*steps):
     >>> from sklearn.preprocessing import StandardScaler
     >>> make_pipeline(StandardScaler(), GaussianNB(priors=None))
     ... # doctest: +NORMALIZE_WHITESPACE
-    Pipeline(steps=[('standardscaler',
+    Pipeline(memory=None,
+         steps=[('standardscaler',
                      StandardScaler(copy=True, with_mean=True, with_std=True)),
                     ('gaussiannb', GaussianNB(priors=None))])

@@ -565,7 +606,7 @@ def _fit_one_transformer(transformer, X, y):
     return transformer.fit(X, y)


-def _transform_one(transformer, name, weight, X):
+def _transform_one(transformer, weight, X):
     res = transformer.transform(X)
     # if we have a weight for this transformer, multiply output
     if weight is None:
@@ -573,7 +614,7 @@ def _transform_one(transformer, name, weight, X):
     return res * weight


-def _fit_transform_one(transformer, name, weight, X, y,
+def _fit_transform_one(transformer, weight, X, y,
                        **fit_params):
     if hasattr(transformer, 'fit_transform'):
         res = transformer.fit_transform(X, y, **fit_params)
@@ -731,7 +772,7 @@ def fit_transform(self, X, y=None, **fit_params):
         """
         self._validate_transformers()
         result = Parallel(n_jobs=self.n_jobs)(
-            delayed(_fit_transform_one)(trans, name, weight, X, y,
+            delayed(_fit_transform_one)(trans, weight, X, y,
                                         **fit_params)
             for name, trans, weight in self._iter())

@@ -761,7 +802,7 @@ def transform(self, X):
             sum of n_components (output dimension) over transformers.
""" Xs = Parallel(n_jobs=self.n_jobs)( - delayed(_transform_one)(trans, name, weight, X) + delayed(_transform_one)(trans, weight, X) for name, trans, weight in self._iter()) if not Xs: # All transformers are None diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index ade5ed3b27e41..33e3128931aff 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -1,6 +1,11 @@ """ Test the pipeline module. """ + +from tempfile import mkdtemp +import shutil +import time + import numpy as np from scipy import sparse @@ -26,6 +31,7 @@ from sklearn.datasets import load_iris from sklearn.preprocessing import StandardScaler from sklearn.feature_extraction.text import CountVectorizer +from sklearn.externals.joblib import Memory JUNK_FOOD_DOCS = ( @@ -125,6 +131,17 @@ def score(self, X, y=None, sample_weight=None): return np.sum(X) +class DummyTransf(Transf): + """Transformer which store the column means""" + + def fit(self, X, y): + self.means_ = np.mean(X, axis=0) + # store timestamp to figure out whether the result of 'fit' has been + # cached or not + self.timestamp_ = time.time() + return self + + def test_pipeline_init(): # Test the various init parameters of the pipeline. assert_raises(TypeError, Pipeline) @@ -520,6 +537,7 @@ def make(): 'm2': mult2, 'm3': None, 'last': mult5, + 'memory': None, 'm2__mult': 2, 'last__mult': 5, }) @@ -799,3 +817,80 @@ def test_step_name_validation(): assert_raise_message(ValueError, message, est.fit, [[1]], [1]) assert_raise_message(ValueError, message, est.fit_transform, [[1]], [1]) + + +def test_pipeline_wrong_memory(): + # Test that an error is raised when memory is not a string or a Memory + # instance + iris = load_iris() + X = iris.data + y = iris.target + # Define memory as an integer + memory = 1 + cached_pipe = Pipeline([('transf', DummyTransf()), ('svc', SVC())], + memory=memory) + assert_raises_regex(ValueError, "'memory' should either be a string or a" + " joblib.Memory instance, got 'memory=1' instead.", + cached_pipe.fit, X, y) + + +def test_pipeline_memory(): + iris = load_iris() + X = iris.data + y = iris.target + cachedir = mkdtemp() + try: + memory = Memory(cachedir=cachedir, verbose=10) + # Test with Transformer + SVC + clf = SVC(probability=True, random_state=0) + transf = DummyTransf() + pipe = Pipeline([('transf', clone(transf)), ('svc', clf)]) + cached_pipe = Pipeline([('transf', transf), ('svc', clf)], + memory=memory) + + # Memoize the transformer at the first fit + cached_pipe.fit(X, y) + pipe.fit(X, y) + # Get the time stamp of the tranformer in the cached pipeline + ts = cached_pipe.named_steps['transf'].timestamp_ + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe.predict(X)) + assert_array_equal(pipe.predict_proba(X), cached_pipe.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe.predict_log_proba(X)) + assert_array_equal(pipe.score(X, y), cached_pipe.score(X, y)) + assert_array_equal(pipe.named_steps['transf'].means_, + cached_pipe.named_steps['transf'].means_) + assert_false(hasattr(transf, 'means_')) + # Check that we are reading the cache while fitting + # a second time + cached_pipe.fit(X, y) + # Check that cached_pipe and pipe yield identical results + assert_array_equal(pipe.predict(X), cached_pipe.predict(X)) + assert_array_equal(pipe.predict_proba(X), cached_pipe.predict_proba(X)) + assert_array_equal(pipe.predict_log_proba(X), + cached_pipe.predict_log_proba(X)) + 
+        assert_array_equal(pipe.score(X, y), cached_pipe.score(X, y))
+        assert_array_equal(pipe.named_steps['transf'].means_,
+                           cached_pipe.named_steps['transf'].means_)
+        assert_equal(ts, cached_pipe.named_steps['transf'].timestamp_)
+        # Create a new pipeline with cloned estimators
+        # Check that even changing the step name does not affect the cache hit
+        clf_2 = SVC(probability=True, random_state=0)
+        transf_2 = DummyTransf()
+        cached_pipe_2 = Pipeline([('transf_2', transf_2), ('svc', clf_2)],
+                                 memory=memory)
+        cached_pipe_2.fit(X, y)
+
+        # Check that cached_pipe and pipe yield identical results
+        assert_array_equal(pipe.predict(X), cached_pipe_2.predict(X))
+        assert_array_equal(pipe.predict_proba(X),
+                           cached_pipe_2.predict_proba(X))
+        assert_array_equal(pipe.predict_log_proba(X),
+                           cached_pipe_2.predict_log_proba(X))
+        assert_array_equal(pipe.score(X, y), cached_pipe_2.score(X, y))
+        assert_array_equal(pipe.named_steps['transf'].means_,
+                           cached_pipe_2.named_steps['transf_2'].means_)
+        assert_equal(ts, cached_pipe_2.named_steps['transf_2'].timestamp_)
+    finally:
+        shutil.rmtree(cachedir)
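For quick reference, the sketch below is not part of the patch itself; it shows the two accepted forms of the new ``memory`` parameter (a cache directory path or a ``joblib.Memory`` instance) as described in the docstring above. The digits dataset and the PCA/SVC steps are arbitrary choices for illustration.

# A minimal usage sketch (not part of the patch): the ``memory`` parameter
# added by this diff accepts either a cache directory path (string) or a
# joblib.Memory instance. Estimators and data here are arbitrary choices.
from tempfile import mkdtemp
from shutil import rmtree

from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.externals.joblib import Memory  # joblib is vendored in sklearn here
from sklearn.pipeline import Pipeline
from sklearn.svm import SVC

digits = load_digits()
X, y = digits.data, digits.target
cachedir = mkdtemp()

# Option 1: pass the cache directory as a string.
pipe = Pipeline([('reduce_dim', PCA()), ('clf', SVC())], memory=cachedir)
pipe.fit(X, y)

# Option 2: pass a joblib.Memory instance (allows e.g. verbosity control).
memory = Memory(cachedir=cachedir, verbose=1)
cached_pipe = Pipeline([('reduce_dim', PCA()), ('clf', SVC())], memory=memory)
# With identical transformer parameters and data, the fitted PCA should be
# loaded from the cache rather than refit.
cached_pipe.fit(X, y)

# Because caching clones the transformers, inspect fitted attributes through
# named_steps rather than through the original PCA instance.
print(cached_pipe.named_steps['reduce_dim'].components_.shape)

# Clean up the temporary cache directory.
rmtree(cachedir)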