From 307c360e248e0ff219c7dadad9f23b5107ccc6c3 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Fri, 8 Sep 2017 13:01:23 -0400 Subject: [PATCH 01/15] Don't modify steps in Pipeline.__init__ remove outdated comment fix also for FeatureUnion [MRG+2] Limiting n_components by both n_features and n_samples instead of just n_features (Recreated PR) (#8742) [MRG+1] Remove hard dependency on nose (#9670) MAINT Stop vendoring sphinx-gallery (#9403) CI upgrade travis to run on new numpy release (#9096) CI Make it possible to run doctests in .rst files with pytest (#9697) * doc/datasets/conftest.py to implement the equivalent of nose fixtures * add conftest.py in root folder to ensure that sklearn local folder is used rather than the package in site-packages * test doc with pytest in Travis * move custom_data_home definition from nose fixture to .rst file [MRG+1] avoid integer overflow by using floats for matthews_corrcoef (#9693) * Fix bug#9622: avoid integer overflow by using floats for matthews_corrcoef * matthews_corrcoef: cosmetic change requested by jnothman * Add test_matthews_corrcoef_overflow for Bug#9622 * test_matthews_corrcoef_overflow: clean-up and make deterministic * matthews_corrcoef: pass dtype=np.float64 to sum & trace instead of using astype * test_matthews_corrcoef_overflow: add simple deterministic tests TST Platform independent hash collision tests in FeatureHasher (#9710) TST More informative error message in test_preserve_trustworthiness_approximately (#9738) add some rudimentary tests for meta-estimators fix extra whitespace in error message add missing if_delegate_has_method in pipeline don't test tuple pipeline for now only copy list if not list already? doesn't seem to help? --- sklearn/pipeline.py | 7 ++++-- sklearn/tests/test_metaestimators.py | 33 ++++++++++++++++++++++++++-- sklearn/utils/estimator_checks.py | 2 +- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 93d8db6497b4d..101f83537d47a 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -184,7 +184,8 @@ def _final_estimator(self): def _fit(self, X, y=None, **fit_params): # shallow copy of steps - this should really be steps_ - self.steps = list(self.steps) + if not isinstance(self.steps, list): + self.steps = list(self.steps) self._validate_steps() # Setup the memory memory = check_memory(self.memory) @@ -250,6 +251,7 @@ def fit(self, X, y=None, **fit_params): self._final_estimator.fit(Xt, y, **fit_params) return self + @if_delegate_has_method(delegate='_final_estimator') def fit_transform(self, X, y=None, **fit_params): """Fit the model and transform with the final estimator @@ -708,7 +710,8 @@ def fit(self, X, y=None): self : FeatureUnion This estimator """ - self.transformer_list = list(self.transformer_list) + if not isinstance(self.transformer_list, list): + self.transformer_list = list(self.transformer_list) self._validate_transformers() transformers = Parallel(n_jobs=self.n_jobs)( delayed(_fit_one_transformer)(trans, X, y) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 36885ee8229d8..983ca74024c41 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -10,11 +10,16 @@ from sklearn.utils.testing import assert_true, assert_false, assert_raises from sklearn.utils.validation import check_is_fitted -from sklearn.pipeline import Pipeline +from sklearn.utils.estimator_checks import check_estimator +from sklearn.pipeline import Pipeline, make_pipeline, make_union from sklearn.model_selection import GridSearchCV, RandomizedSearchCV -from sklearn.feature_selection import RFE, RFECV +from sklearn.feature_selection import RFE, RFECV, SelectFromModel from sklearn.ensemble import BaggingClassifier from sklearn.exceptions import NotFittedError +from sklearn.preprocessing import StandardScaler +from sklearn.linear_model import LogisticRegression +from sklearn.decomposition import PCA +from sklearn.cluster import KMeans class DelegatorData(object): @@ -47,6 +52,30 @@ def __init__(self, name, construct, skip_methods=(), ] +def test_metaestimators_check_estimator(): + estimators = [ + # pipeline + # this fails because tuple is converted to list in fit: + # Pipeline((('ss', StandardScaler()),)), + Pipeline([('ss', StandardScaler())]), + make_pipeline(StandardScaler(), LogisticRegression()), + # union + make_union(StandardScaler()), + # union and pipeline + make_pipeline(make_union(PCA(), StandardScaler()), + LogisticRegression()), + # pipeline with clustering + make_pipeline(KMeans()), + # SelectFromModel + make_pipeline(SelectFromModel(LogisticRegression()), + LogisticRegression()), + # grid-search + GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}) + ] + for estimator in estimators: + yield check_estimator, estimator + + def test_metaestimator_delegation(): # Ensures specified metaestimators have methods iff subestimator does def hides(method): diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 3e7cb198a9d12..e4398c5dadc97 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1516,7 +1516,7 @@ def check_estimators_overwrite_params(name, estimator_orig): # fixed the random_state params recursively to be integer seeds. assert_equal(hash(new_value), hash(original_value), "Estimator %s should not change or mutate " - " the parameter %s from %s to %s during fit." + "the parameter %s from %s to %s during fit." % (name, param_name, original_value, new_value)) From 88457604a2469d4c6bd36c2d97f7557a7ad7191c Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 19 Sep 2017 14:16:59 -0400 Subject: [PATCH 02/15] add check with last step None in pipeline. --- sklearn/tests/test_metaestimators.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 983ca74024c41..c8a07eaed4f44 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -72,6 +72,9 @@ def test_metaestimators_check_estimator(): # grid-search GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}) ] + none_pipe = make_pipeline(StandardScaler(), KMeans()) + none_pipe.set_params(kmeans=None) + estimators.append(none_pipe) for estimator in estimators: yield check_estimator, estimator From 40612a43d380f1bbafc56ce965339bbc47cc4e0b Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 12:14:12 -0400 Subject: [PATCH 03/15] use parametrize_with_checks instead of check_estimator --- sklearn/tests/test_metaestimators.py | 51 ++++++++++++++-------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 736942c41197a..1a8934a4ed231 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -8,7 +8,7 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.validation import check_is_fitted -from sklearn.utils.estimator_checks import check_estimator +from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.pipeline import Pipeline, make_pipeline, make_union from sklearn.model_selection import GridSearchCV, RandomizedSearchCV from sklearn.feature_selection import RFE, RFECV, SelectFromModel @@ -49,32 +49,31 @@ def __init__(self, name, construct, skip_methods=(), 'predict']) ] +TESTED_META = [ + # pipeline + # this fails because tuple is converted to list in fit: + # Pipeline((('ss', StandardScaler()),)), + Pipeline([('ss', StandardScaler())]), + make_pipeline(StandardScaler(), LogisticRegression()), + # union + make_union(StandardScaler()), + # union and pipeline + make_pipeline(make_union(PCA(), StandardScaler()), + LogisticRegression()), + # pipeline with clustering + make_pipeline(KMeans()), + # SelectFromModel + make_pipeline(SelectFromModel(LogisticRegression()), + LogisticRegression()), + # grid-search + GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}), + make_pipeline(StandardScaler(), None) +] + -def test_metaestimators_check_estimator(): - estimators = [ - # pipeline - # this fails because tuple is converted to list in fit: - # Pipeline((('ss', StandardScaler()),)), - Pipeline([('ss', StandardScaler())]), - make_pipeline(StandardScaler(), LogisticRegression()), - # union - make_union(StandardScaler()), - # union and pipeline - make_pipeline(make_union(PCA(), StandardScaler()), - LogisticRegression()), - # pipeline with clustering - make_pipeline(KMeans()), - # SelectFromModel - make_pipeline(SelectFromModel(LogisticRegression()), - LogisticRegression()), - # grid-search - GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}) - ] - none_pipe = make_pipeline(StandardScaler(), KMeans()) - none_pipe.set_params(kmeans=None) - estimators.append(none_pipe) - for estimator in estimators: - yield check_estimator, estimator +@parametrize_with_checks(TESTED_META) +def test_metaestimators_check_estimator(estimator, check): + check(estimator) def test_metaestimator_delegation(): From 8a20c88abc5cfe5ff7845de380ca7ca2aaf3b812 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 12:33:20 -0400 Subject: [PATCH 04/15] minor fixes, try to hack the tags --- sklearn/pipeline.py | 5 +++++ sklearn/tests/test_metaestimators.py | 8 +++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index bc0937f600883..0297d37570313 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -230,6 +230,11 @@ def __getitem__(self, ind): return self.named_steps[ind] return est + def _more_tags(self): + # hack to make common cases work: + # we assume the pipeline can handle NaN if the first step can? + return {'allow_nan': self.steps[0][1]._get_tags()['allow_nan']} + @property def _estimator_type(self): return self.steps[-1][1]._estimator_type diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 1a8934a4ed231..21245529e4033 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -52,7 +52,8 @@ def __init__(self, name, construct, skip_methods=(), TESTED_META = [ # pipeline # this fails because tuple is converted to list in fit: - # Pipeline((('ss', StandardScaler()),)), + Pipeline((('ss', StandardScaler()),)), + # all pipelines fail because they don't clone: Pipeline([('ss', StandardScaler())]), make_pipeline(StandardScaler(), LogisticRegression()), # union @@ -61,12 +62,13 @@ def __init__(self, name, construct, skip_methods=(), make_pipeline(make_union(PCA(), StandardScaler()), LogisticRegression()), # pipeline with clustering - make_pipeline(KMeans()), + make_pipeline(KMeans(random_state=0)), # SelectFromModel make_pipeline(SelectFromModel(LogisticRegression()), LogisticRegression()), # grid-search - GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}), + GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}, cv=2), + # will fail tragically make_pipeline(StandardScaler(), None) ] From 83c0a8f5e9c8ed561350f3254cf15c9b4ca78e63 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 12:47:29 -0400 Subject: [PATCH 05/15] cheat a bit and skip the tests the pipeline fails b/c of clone --- sklearn/feature_selection/from_model.py | 7 ++++--- sklearn/pipeline.py | 14 ++++++++++---- sklearn/tests/test_metaestimators.py | 6 +++++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py index 6d732d0e43dfd..f77324a2e9ad3 100644 --- a/sklearn/feature_selection/from_model.py +++ b/sklearn/feature_selection/from_model.py @@ -146,9 +146,10 @@ def _get_support_mask(self): elif hasattr(self, 'estimator_'): estimator = self.estimator_ else: - raise ValueError('Either fit the model before transform or set' - ' "prefit=True" while passing the fitted' - ' estimator to the constructor.') + raise NotFittedError( + 'Either fit the model before transform or set' + ' "prefit=True" while passing the fitted' + ' estimator to the constructor.') scores = _get_feature_importances(estimator, self.norm_order) threshold = _calculate_threshold(estimator, scores, self.threshold) if self.max_features is not None: diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 0297d37570313..9cb666a17f2cc 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -232,8 +232,9 @@ def __getitem__(self, ind): def _more_tags(self): # hack to make common cases work: - # we assume the pipeline can handle NaN if the first step can? - return {'allow_nan': self.steps[0][1]._get_tags()['allow_nan']} + # we assume the pipeline can handle NaN if all the steps can? + return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] + for s in self.steps])} @property def _estimator_type(self): @@ -262,8 +263,7 @@ def _log_message(self, step_idx): def _fit(self, X, y=None, **fit_params): # shallow copy of steps - this should really be steps_ - if not isinstance(self.steps, list): - self.steps = list(self.steps) + self.steps = list(self.steps) self._validate_steps() # Setup the memory memory = check_memory(self.memory) @@ -988,6 +988,12 @@ def _update_transformer_list(self, transformers): else next(transformers)) for name, old in self.transformer_list] + def _more_tags(self): + # hack to make common cases work: + # we assume the pipeline can handle NaN if all the steps can? + return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] + for s in self.transformer_list])} + def make_union(*transformers, **kwargs): """Construct a FeatureUnion from the given transformers. diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 21245529e4033..d2a3f5e322459 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -69,12 +69,16 @@ def __init__(self, name, construct, skip_methods=(), # grid-search GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}, cv=2), # will fail tragically - make_pipeline(StandardScaler(), None) + # make_pipeline(StandardScaler(), None) ] @parametrize_with_checks(TESTED_META) def test_metaestimators_check_estimator(estimator, check): + if check.func.__name__ in ["check_estimators_overwrite_params", + "check_dont_overwrite_parameters"]: + # we don't clone in pipeline or feature union + return check(estimator) From 07260c6b81f098378873645decddb0088d653834 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 13:01:00 -0400 Subject: [PATCH 06/15] make it a bit easier --- sklearn/tests/test_metaestimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index d2a3f5e322459..1033025b7c8f2 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -64,7 +64,7 @@ def __init__(self, name, construct, skip_methods=(), # pipeline with clustering make_pipeline(KMeans(random_state=0)), # SelectFromModel - make_pipeline(SelectFromModel(LogisticRegression()), + make_pipeline(SelectFromModel(LogisticRegression(), threshold=-np.inf), LogisticRegression()), # grid-search GridSearchCV(LogisticRegression(), {'C': [0.1, 1]}, cv=2), From 206630c523a6cfef5a20b1d80509d3c288e60429 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 13:04:34 -0400 Subject: [PATCH 07/15] fix comment --- sklearn/pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 9cb666a17f2cc..7b9445288ec72 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -232,7 +232,7 @@ def __getitem__(self, ind): def _more_tags(self): # hack to make common cases work: - # we assume the pipeline can handle NaN if all the steps can? + # we assume the pipeline can handle NaN if all the steps can return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] for s in self.steps])} @@ -990,7 +990,7 @@ def _update_transformer_list(self, transformers): def _more_tags(self): # hack to make common cases work: - # we assume the pipeline can handle NaN if all the steps can? + # we assume the FeatureUnion can handle NaN if all the steps can return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] for s in self.transformer_list])} From 3848f679d911b38049e66298de058486e72644a8 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 13:05:25 -0400 Subject: [PATCH 08/15] tags for feature union are actually correct --- sklearn/pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 7b9445288ec72..aeb38b10cc50f 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -989,8 +989,7 @@ def _update_transformer_list(self, transformers): for name, old in self.transformer_list] def _more_tags(self): - # hack to make common cases work: - # we assume the FeatureUnion can handle NaN if all the steps can + # The FeatureUnion can handle NaNs if all the steps can. return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] for s in self.transformer_list])} From 06e5bbed4bac782f330020d9dd4f50ba70705a8c Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 15:17:28 -0400 Subject: [PATCH 09/15] do proper delegation in pipeline fit_transform --- sklearn/pipeline.py | 10 ++++++++-- sklearn/tests/test_pipeline.py | 30 ++++++++++++++++++------------ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index aeb38b10cc50f..507c6f90721bd 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -356,8 +356,8 @@ def fit(self, X, y=None, **fit_params): self._final_estimator.fit(Xt, y, **fit_params) return self - @if_delegate_has_method(delegate='_final_estimator') - def fit_transform(self, X, y=None, **fit_params): + @property + def fit_transform(self): """Fit the model and transform with the final estimator Fits all the transforms one after the other and transforms the @@ -384,6 +384,12 @@ def fit_transform(self, X, y=None, **fit_params): Xt : array-like, shape = [n_samples, n_transformed_features] Transformed samples """ + # we have a fit_transform whenever we have a transform + self._validate_steps() + self.transform + return self._fit_transform + + def _fit_transform(self, X, y=None, **fit_params): last_step = self._final_estimator Xt, fit_params = self._fit(X, y, **fit_params) with _print_elapsed_time('Pipeline', diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index e02b5ef96b7b0..d3dca7763dc25 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -607,7 +607,10 @@ def test_set_pipeline_steps(): # With invalid data pipeline.set_params(steps=[('junk', ())]) assert_raises(TypeError, pipeline.fit, [[1]], [1]) - assert_raises(TypeError, pipeline.fit_transform, [[1]], [1]) + with pytest.raises(TypeError): + # we can't even find out whether fit_transform + # exists without raising a TypeError + pipeline.fit_transform([[1]], [1]) def test_pipeline_named_steps(): @@ -949,29 +952,32 @@ def test_step_name_validation(): # we validate in construction (despite scikit-learn convention) bad_steps3 = [('a', Mult(2)), (param, Mult(3))] for bad_steps, message in [ - (bad_steps1, "Estimator names must not contain __: got ['a__q']"), - (bad_steps2, "Names provided are not unique: ['a', 'a']"), + (bad_steps1, "Estimator names must not contain __: got \['a__q'\]"), + (bad_steps2, "Names provided are not unique: \['a', 'a'\]"), (bad_steps3, "Estimator names conflict with constructor " - "arguments: ['%s']" % param), + "arguments: \['%s'\]" % param), ]: # three ways to make invalid: # - construction - assert_raise_message(ValueError, message, cls, - **{param: bad_steps}) + with pytest.raises(ValueError, match=message): + cls(**{param: bad_steps}) # - setattr est = cls(**{param: [('a', Mult(1))]}) setattr(est, param, bad_steps) - assert_raise_message(ValueError, message, est.fit, [[1]], [1]) - assert_raise_message(ValueError, message, est.fit_transform, - [[1]], [1]) + with pytest.raises(ValueError, match=message): + est.fit([[1]], [1]) + + with pytest.raises(ValueError, match=message): + est.fit_transform([[1]], [1]) # - set_params est = cls(**{param: [('a', Mult(1))]}) est.set_params(**{param: bad_steps}) - assert_raise_message(ValueError, message, est.fit, [[1]], [1]) - assert_raise_message(ValueError, message, est.fit_transform, - [[1]], [1]) + with pytest.raises(ValueError, match=message): + est.fit([[1]], [1]) + with pytest.raises(ValueError, match=message): + est.fit_transform([[1]], [1]) def test_set_params_nested_pipeline(): From 5e6ee2bbef6482721dcbad832d677fdc4b264827 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 15:30:27 -0400 Subject: [PATCH 10/15] pep8 --- sklearn/tests/test_pipeline.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index d3dca7763dc25..c6a388af0b2c4 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -169,8 +169,8 @@ def test_pipeline_init(): clf = NoTrans() pipe = Pipeline([('svc', clf)]) assert (pipe.get_params(deep=True) == - dict(svc__a=None, svc__b=None, svc=clf, - **pipe.get_params(deep=False))) + dict(svc__a=None, svc__b=None, svc=clf, + **pipe.get_params(deep=False))) # Check that params are set pipe.set_params(svc__a=0.1) @@ -675,15 +675,15 @@ def make(): assert_array_equal([exp], pipeline.fit(X).predict(X)) assert_array_equal(X, pipeline.inverse_transform([[exp]])) assert (pipeline.get_params(deep=True) == - {'steps': pipeline.steps, - 'm2': mult2, - 'm3': passthrough, - 'last': mult5, - 'memory': None, - 'm2__mult': 2, - 'last__mult': 5, - 'verbose': False - }) + {'steps': pipeline.steps, + 'm2': mult2, + 'm3': passthrough, + 'last': mult5, + 'memory': None, + 'm2__mult': 2, + 'last__mult': 5, + 'verbose': False + }) pipeline.set_params(m2=passthrough) exp = 5 @@ -952,10 +952,11 @@ def test_step_name_validation(): # we validate in construction (despite scikit-learn convention) bad_steps3 = [('a', Mult(2)), (param, Mult(3))] for bad_steps, message in [ - (bad_steps1, "Estimator names must not contain __: got \['a__q'\]"), - (bad_steps2, "Names provided are not unique: \['a', 'a'\]"), - (bad_steps3, "Estimator names conflict with constructor " - "arguments: \['%s'\]" % param), + (bad_steps1, r"Estimator names must not contain __:" + r" got \['a__q'\]"), + (bad_steps2, r"Names provided are not unique: \['a', 'a'\]"), + (bad_steps3, r"Estimator names conflict with constructor " + r"arguments: \['%s'\]" % param), ]: # three ways to make invalid: # - construction From 556dde552e7811b74e24f9973647832b027b8bb5 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Mon, 9 Sep 2019 16:25:10 -0400 Subject: [PATCH 11/15] add fit_transform ducktyping tests --- sklearn/pipeline.py | 6 +++++- sklearn/tests/test_pipeline.py | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 507c6f90721bd..7ce336b78e788 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -386,7 +386,11 @@ def fit_transform(self): """ # we have a fit_transform whenever we have a transform self._validate_steps() - self.transform + if self._final_estimator != 'passthrough': + try: + self._final_estimator.transform + except AttributeError: + self._final_estimator.fit_transform return self._fit_transform def _fit_transform(self, X, y=None, **fit_params): diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index c6a388af0b2c4..98ff5a4312afa 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -81,6 +81,11 @@ def inverse_transform(self, X): return X +class FitTransf(NoTrans): + def fit_transform(self, X, y=None): + return X + + class TransfFitParams(Transf): def fit(self, X, y, **fit_params): @@ -728,6 +733,7 @@ def test_pipeline_ducktyping(): pipeline.predict pipeline.transform pipeline.inverse_transform + pipeline.fit_transform pipeline = make_pipeline(Transf()) assert not hasattr(pipeline, 'predict') @@ -739,6 +745,7 @@ def test_pipeline_ducktyping(): assert not hasattr(pipeline, 'predict') pipeline.transform pipeline.inverse_transform + pipeline.fit_transform pipeline = make_pipeline(Transf(), NoInvTransf()) assert not hasattr(pipeline, 'predict') @@ -750,6 +757,10 @@ def test_pipeline_ducktyping(): pipeline.transform assert not hasattr(pipeline, 'inverse_transform') + pipeline = make_pipeline(FitTransf()) + assert not hasattr(pipeline, 'transform') + pipeline.fit_transform + def test_make_pipeline(): t1 = Transf() From a2a1e32d55ebab7252e061e8041d9039a38601b2 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 24 Sep 2019 14:39:56 -0400 Subject: [PATCH 12/15] address some comments --- sklearn/pipeline.py | 18 +++++++++--------- sklearn/tests/test_metaestimators.py | 12 ++++++------ sklearn/tests/test_pipeline.py | 3 ++- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 7ce336b78e788..d455bdc6b22fc 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -233,8 +233,8 @@ def __getitem__(self, ind): def _more_tags(self): # hack to make common cases work: # we assume the pipeline can handle NaN if all the steps can - return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] - for s in self.steps])} + return {'allow_nan': np.all(s[1]._get_tags()['allow_nan'] + for s in self.steps)} @property def _estimator_type(self): @@ -384,13 +384,13 @@ def fit_transform(self): Xt : array-like, shape = [n_samples, n_transformed_features] Transformed samples """ - # we have a fit_transform whenever we have a transform self._validate_steps() + # pipeline has a fit_transform whenever the final estimator has + # transform or fit_transform if self._final_estimator != 'passthrough': - try: - self._final_estimator.transform - except AttributeError: - self._final_estimator.fit_transform + if (not hasattr(self._final_estimator, 'transform') + and not hasattr(self.final_estimator, 'fit_transform')): + raise AttributeError return self._fit_transform def _fit_transform(self, X, y=None, **fit_params): @@ -1000,8 +1000,8 @@ def _update_transformer_list(self, transformers): def _more_tags(self): # The FeatureUnion can handle NaNs if all the steps can. - return {'allow_nan': np.all([s[1]._get_tags()['allow_nan'] - for s in self.transformer_list])} + return {'allow_nan': np.all(s[1]._get_tags()['allow_nan'] + for s in self.transformer_list)} def make_union(*transformers, **kwargs): diff --git a/sklearn/tests/test_metaestimators.py b/sklearn/tests/test_metaestimators.py index 1033025b7c8f2..03a96bc03da05 100644 --- a/sklearn/tests/test_metaestimators.py +++ b/sklearn/tests/test_metaestimators.py @@ -9,7 +9,7 @@ from sklearn.utils.testing import assert_raises from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import parametrize_with_checks -from sklearn.pipeline import Pipeline, make_pipeline, make_union +from sklearn.pipeline import Pipeline, make_pipeline, make_union, FeatureUnion from sklearn.model_selection import GridSearchCV, RandomizedSearchCV from sklearn.feature_selection import RFE, RFECV, SelectFromModel from sklearn.ensemble import BaggingClassifier @@ -50,10 +50,8 @@ def __init__(self, name, construct, skip_methods=(), ] TESTED_META = [ - # pipeline - # this fails because tuple is converted to list in fit: + # pipelines Pipeline((('ss', StandardScaler()),)), - # all pipelines fail because they don't clone: Pipeline([('ss', StandardScaler())]), make_pipeline(StandardScaler(), LogisticRegression()), # union @@ -75,8 +73,10 @@ def __init__(self, name, construct, skip_methods=(), @parametrize_with_checks(TESTED_META) def test_metaestimators_check_estimator(estimator, check): - if check.func.__name__ in ["check_estimators_overwrite_params", - "check_dont_overwrite_parameters"]: + if (check.func.__name__ in ["check_estimators_overwrite_params", + "check_dont_overwrite_parameters"] + and (isinstance(estimator, Pipeline) + or isinstance(estimator, FeatureUnion))): # we don't clone in pipeline or feature union return check(estimator) diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index 98ff5a4312afa..7b65625f26658 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -82,6 +82,7 @@ def inverse_transform(self, X): class FitTransf(NoTrans): + # has fit_transform but not transform def fit_transform(self, X, y=None): return X @@ -964,7 +965,7 @@ def test_step_name_validation(): bad_steps3 = [('a', Mult(2)), (param, Mult(3))] for bad_steps, message in [ (bad_steps1, r"Estimator names must not contain __:" - r" got \['a__q'\]"), + r" got \['a__q'\]"), (bad_steps2, r"Names provided are not unique: \['a', 'a'\]"), (bad_steps3, r"Estimator names conflict with constructor " r"arguments: \['%s'\]" % param), From 94a0e5ebc884c010d8bdb81a341cc7407dcf89d7 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 24 Sep 2019 15:31:41 -0400 Subject: [PATCH 13/15] fix typo --- sklearn/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index d455bdc6b22fc..47e02964c880c 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -389,7 +389,7 @@ def fit_transform(self): # transform or fit_transform if self._final_estimator != 'passthrough': if (not hasattr(self._final_estimator, 'transform') - and not hasattr(self.final_estimator, 'fit_transform')): + and not hasattr(self._final_estimator, 'fit_transform')): raise AttributeError return self._fit_transform From 2625b58d17125b63c0f3c13c021260b65a10e6f7 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 24 Sep 2019 15:49:19 -0400 Subject: [PATCH 14/15] use standard all on generator expression --- sklearn/pipeline.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 47e02964c880c..46d713a272fac 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -233,8 +233,8 @@ def __getitem__(self, ind): def _more_tags(self): # hack to make common cases work: # we assume the pipeline can handle NaN if all the steps can - return {'allow_nan': np.all(s[1]._get_tags()['allow_nan'] - for s in self.steps)} + return {'allow_nan': all(s[1]._get_tags()['allow_nan'] for s in + self.steps)} @property def _estimator_type(self): @@ -1000,8 +1000,8 @@ def _update_transformer_list(self, transformers): def _more_tags(self): # The FeatureUnion can handle NaNs if all the steps can. - return {'allow_nan': np.all(s[1]._get_tags()['allow_nan'] - for s in self.transformer_list)} + return {'allow_nan': all(s[1]._get_tags()['allow_nan'] for s in + self.transformer_list)} def make_union(*transformers, **kwargs): From 646e0d8abb0e83e9754b973130b655d52ac217f8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 31 Jul 2020 10:51:17 +0200 Subject: [PATCH 15/15] changes during merge --- sklearn/feature_selection/_from_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py index b73ce80d32631..69539bf235d02 100644 --- a/sklearn/feature_selection/_from_model.py +++ b/sklearn/feature_selection/_from_model.py @@ -173,9 +173,9 @@ def _get_support_mask(self): elif hasattr(self, 'estimator_'): estimator = self.estimator_ else: - raise ValueError('Either fit the model before transform or set' - ' "prefit=True" while passing the fitted' - ' estimator to the constructor.') + raise NotFittedError('Either fit the model before transform or set' + ' "prefit=True" while passing the fitted' + ' estimator to the constructor.') scores = _get_feature_importances( estimator=estimator, getter=self.importance_getter, transform_func='norm', norm_order=self.norm_order)