From 608d3ad057cc5acb4ad43e9e75d3c7e877dfd40b Mon Sep 17 00:00:00 2001 From: Mathieu Blondel Date: Thu, 14 Nov 2013 18:23:09 +0900 Subject: [PATCH] Add syntactic sugar for pipelines. --- sklearn/base.py | 12 +++++++++++- sklearn/tests/test_pipeline.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index df85044e972b9..4f7a1d749513a 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -262,13 +262,23 @@ def __repr__(self): class_name = self.__class__.__name__ return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False), offset=len(class_name),),) - def __str__(self): class_name = self.__class__.__name__ return '%s(%s)' % (class_name, _pprint(self.get_params(deep=True), offset=len(class_name), printer=str,),) + def __or__(self, estimator): + from sklearn.pipeline import Pipeline + if isinstance(self, Pipeline): + name = estimator.__class__.__name__ + steps = self.steps + [(name, estimator)] + else: + name1 = self.__class__.__name__ + name2 = estimator.__class__.__name__ + steps = [(name1, self), (name2, estimator)] + return Pipeline(steps) + ############################################################################### class ClassifierMixin(object): diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index c3b2fbc170b70..295e2a01395d0 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -195,6 +195,26 @@ def test_pipeline_methods_preprocessing_svm(): pipe.score(X, y) +def test_pipeline_syntactic_sugar(): + iris = load_iris() + X = iris.data + y = iris.target + pipeline = StandardScaler() | PCA(n_components=2) | SVC(random_state=0) + assert_true(isinstance(pipeline, Pipeline)) + + assert_equal(pipeline.steps[0][0], "StandardScaler") + assert_true(isinstance(pipeline.steps[0][1], StandardScaler)) + + assert_equal(pipeline.steps[1][0], "PCA") + assert_true(isinstance(pipeline.steps[1][1], PCA)) + + assert_equal(pipeline.steps[2][0], "SVC") + assert_true(isinstance(pipeline.steps[2][1], SVC)) + + y_pred = pipeline.fit(X, y).predict(X) + assert_equal(np.mean(y == y_pred), 0.92) + + def test_feature_union(): # basic sanity check for feature union iris = load_iris()