diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst
index 4356b3fe8d640..d6f5c8645b282 100644
--- a/doc/modules/pipeline.rst
+++ b/doc/modules/pipeline.rst
@@ -269,3 +269,62 @@ and ignored by setting to ``None``::
 * :ref:`sphx_glr_auto_examples_plot_feature_stacker.py`
 * :ref:`sphx_glr_auto_examples_hetero_feature_union.py`
+
+.. _frozen:
+
+Frozen estimators
+=================
+.. currentmodule:: sklearn
+
+It can be useful to pre-fit an estimator before including it in a Pipeline,
+FeatureUnion or other meta-estimators. Example applications include:
+
+* transfer learning: incorporating a transformer trained on a large unlabelled
+  dataset in a prediction pipeline where the data to be modelled is much smaller
+* feature selection on the basis of an already fitted predictive model
+* calibrating an already fitted classifier for probabilistic output
+
+To enable this, some meta-estimators will refrain from fitting an estimator
+whose ``frozen`` attribute is set to ``True``. For example, without transfer
+learning::
+
+    >>> from sklearn.datasets import load_...
+    >>> from sklearn.model_selection import cross_val_score
+    >>> from sklearn.pipeline import make_pipeline
+    >>> from sklearn.feature_extraction.text import TfidfVectorizer
+    >>> from sklearn.linear_model import LogisticRegression
+    >>> cross_val_score(make_pipeline(TfidfVectorizer(), LogisticRegression()),
+    ...                 X, y)
+
+With transfer learning::
+
+    >>> tfidf = TfidfVectorizer().fit(large_X)
+    >>> tfidf.frozen = True
+    >>> cross_val_score(make_pipeline(tfidf, LogisticRegression()),
+    ...                 X, y)
+
+The following meta-estimators may make use of frozen estimators:
+
+* :class:`pipeline.Pipeline`
+* :class:`pipeline.FeatureUnion`
+* :class:`ensemble.VotingClassifier`
+* :class:`feature_selection.SelectFromModel`
+* :class:`calibration.CalibratedClassifierCV` with ``cv='prefit'``
+
+:func:`base.frozen_fit` is also available for developers of meta-estimators.
+
+.. note::
+    When an estimator is frozen, calling :func:`clone` on it will return
+    the estimator itself::
+
+        >>> from sklearn.base import clone
+        >>> clone(tfidf) is tfidf
+        True
+
+    This allows the model to be left untouched in cross-validation and in
+    meta-estimators that copy the estimator with ``clone``.
+
+.. warning:: Data leakage
+    Take care not to introduce data leakage with this method: do not
+    incorporate your test set into the training of a frozen component,
+    unless it would be realistic to do so in the target application.
diff --git a/sklearn/base.py b/sklearn/base.py
index aa4f9f9ce17c1..30c8947b23f83 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -44,6 +44,8 @@ def clone(estimator, safe=True):
     """
     estimator_type = type(estimator)
+    if getattr(estimator, 'frozen', False):
+        return estimator
     # XXX: not handling dictionaries
     if estimator_type in (list, tuple, set, frozenset):
         return estimator_type([clone(e, safe=safe) for e in estimator])
@@ -578,3 +580,41 @@ def is_regressor(estimator):
         True if estimator is a regressor and False otherwise.
     """
     return getattr(estimator, "_estimator_type", None) == "regressor"
+
+
+def frozen_fit(estimator, method, X, y, **kwargs):
+    """Fit the estimator unless it is frozen, and return the result of method
+
+    A frozen estimator has its attribute ``frozen`` set to True.
+
+    Parameters
+    ----------
+    estimator : estimator instance
+        The estimator to fit, unless it is frozen.
+    method : str
+        One of {'fit', 'fit_transform', 'fit_predict'} or similar.
+    X : array-like
+        Training data.
+    y : array-like, optional
+        Training targets; only passed to ``method`` when fitting.
+    kwargs : dict
+        Additional fit parameters; only passed to ``method`` when fitting.
+
+    Returns
+    -------
+    out
+        ``estimator`` if ``method == 'fit'``, else the output of ``transform``
+        etc. If the estimator has the attribute ``frozen`` set to True, it
+        will not be refit.
+ """ + if getattr(estimator, 'frozen', False): + if method == 'fit': + return estimator + if not method.startswith('fit_'): + raise ValueError('method must be "fit" or begin with "fit_"') + method = getattr(estimator, method[4:]) + # FIXME: what do we do with kwargs? + return method(X) + else: + method = getattr(estimator, method) + return method(X, y, **kwargs) diff --git a/sklearn/calibration.py b/sklearn/calibration.py index 0d2f76cd12239..7a21ec3dbfe74 100644 --- a/sklearn/calibration.py +++ b/sklearn/calibration.py @@ -45,7 +45,8 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin): base_estimator : instance BaseEstimator The classifier whose output decision function needs to be calibrated to offer more accurate predict_proba outputs. If cv=prefit, the - classifier must have been fit already on data. + classifier must have been fit already on data, and it is recommended + that the classifier be frozen (see :ref:`frozen`) in this case. method : 'sigmoid' or 'isotonic' The method to use for calibration. Can be 'sigmoid' which diff --git a/sklearn/ensemble/tests/test_voting_classifier.py b/sklearn/ensemble/tests/test_voting_classifier.py index 4765d0e32d0bb..93967766cabd0 100644 --- a/sklearn/ensemble/tests/test_voting_classifier.py +++ b/sklearn/ensemble/tests/test_voting_classifier.py @@ -367,6 +367,9 @@ def test_estimator_weights_format(): assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X)) +def test_frozen(): + raise NotImplementedError() + def test_transform(): """Check transform method of VotingClassifier on toy dataset.""" clf1 = LogisticRegression(random_state=123) diff --git a/sklearn/ensemble/voting_classifier.py b/sklearn/ensemble/voting_classifier.py index 88b329d836978..aa2725e547950 100644 --- a/sklearn/ensemble/voting_classifier.py +++ b/sklearn/ensemble/voting_classifier.py @@ -17,6 +17,7 @@ from ..base import ClassifierMixin from ..base import TransformerMixin from ..base import clone +from ..base import frozen_fit from ..preprocessing import LabelEncoder from ..externals.joblib import Parallel, delayed from ..utils.validation import has_fit_parameter, check_is_fitted @@ -26,9 +27,9 @@ def _parallel_fit_estimator(estimator, X, y, sample_weight): """Private function used to fit an estimator within a job.""" if sample_weight is not None: - estimator.fit(X, y, sample_weight) + frozen_fit(estimator, 'fit', X, y, sample_weight=sample_weight) else: - estimator.fit(X, y) + frozen_fit(estimator, 'fit', X, y) return estimator @@ -47,6 +48,8 @@ class VotingClassifier(_BaseComposition, ClassifierMixin, TransformerMixin): ``self.estimators_``. An estimator can be set to `None` using ``set_params``. + Some of these estimators may be frozen (see :ref:`frozen`). + voting : str, {'hard', 'soft'} (default='hard') If 'hard', uses predicted class labels for majority rule voting. 
Else if 'soft', predicts the class label based on the argmax of
diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py
index 2502643453d79..85984fad3e55a 100644
--- a/sklearn/feature_selection/from_model.py
+++ b/sklearn/feature_selection/from_model.py
@@ -1,10 +1,13 @@
 # Authors: Gilles Louppe, Mathieu Blondel, Maheshakya Wijewardena
 # License: BSD 3 clause
+import warnings
+
 import numpy as np
 from .base import SelectorMixin
-from ..base import BaseEstimator, clone, MetaEstimatorMixin
+from ..base import BaseEstimator, clone, MetaEstimatorMixin, frozen_fit
 from ..externals import six
 from ..exceptions import NotFittedError
+from ..utils.validation import check_is_fitted
@@ -86,9 +89,10 @@ class SelectFromModel(BaseEstimator, SelectorMixin, MetaEstimatorMixin):
     ----------
     estimator : object
         The base estimator from which the transformer is built.
-        This can be both a fitted (if ``prefit`` is set to True)
-        or a non-fitted estimator. The estimator must have either a
-        ``feature_importances_`` or ``coef_`` attribute after fitting.
+        The estimator must have either a ``feature_importances_``
+        or ``coef_`` attribute after fitting.
+
+        This estimator may be frozen (see :ref:`frozen`).
 
     threshold : string, float, optional default None
         The threshold value to use for feature selection. Features whose
@@ -100,14 +104,10 @@ class SelectFromModel(BaseEstimator, SelectorMixin, MetaEstimatorMixin):
         or implicitly (e.g, Lasso), the threshold used is 1e-5.
         Otherwise, "mean" is used by default.
 
-    prefit : bool, default False
-        Whether a prefit model is expected to be passed into the constructor
-        directly or not. If True, ``transform`` must be called directly
-        and SelectFromModel cannot be used with ``cross_val_score``,
-        ``GridSearchCV`` and similar utilities that clone the estimator.
-        Otherwise train the model using ``fit`` and then ``transform`` to do
-        feature selection.
-
+    prefit : bool, optional (default=None)
+        Deprecated; set ``estimator.frozen = True`` instead (see
+        :ref:`frozen`). Will be removed in version 0.22.
+
     norm_order : non-zero int, inf, -inf, default 1
         Order of the norm used to filter the vectors of coefficients below
         ``threshold`` in the case where the ``coef_`` attribute of the
@@ -123,22 +123,18 @@ class SelectFromModel(BaseEstimator, SelectorMixin, MetaEstimatorMixin):
     threshold_ : float
         The threshold value used for feature selection.
     """
-    def __init__(self, estimator, threshold=None, prefit=False, norm_order=1):
+    def __init__(self, estimator, threshold=None, prefit=None, norm_order=1):
         self.estimator = estimator
         self.threshold = threshold
         self.prefit = prefit
         self.norm_order = norm_order
 
     def _get_support_mask(self):
-        # SelectFromModel can directly call on transform.
         if self.prefit:
             estimator = self.estimator
-        elif hasattr(self, 'estimator_'):
-            estimator = self.estimator_
         else:
-            raise ValueError(
-                'Either fit SelectFromModel before transform or set "prefit='
-                'True" and pass a fitted estimator to the constructor.')
+            check_is_fitted(self, 'estimator_')
+            estimator = self.estimator_
         scores = _get_feature_importances(estimator, self.norm_order)
         threshold = _calculate_threshold(estimator, scores, self.threshold)
         return scores >= threshold
@@ -162,11 +158,15 @@ def fit(self, X, y=None, **fit_params):
         self : object
             Returns self.
         """
+        if self.prefit is not None:
+            warnings.warn('Parameter prefit is deprecated and will be removed '
+                          'in version 0.22. Set estimator.frozen = True '
+                          'instead.', DeprecationWarning)
         if self.prefit:
             raise NotFittedError(
                 "Since 'prefit=True', call transform directly")
         self.estimator_ = clone(self.estimator)
-        self.estimator_.fit(X, y, **fit_params)
+        frozen_fit(self.estimator_, 'fit', X, y, **fit_params)
         return self
 
     @property
diff --git a/sklearn/feature_selection/tests/test_from_model.py b/sklearn/feature_selection/tests/test_from_model.py
index ae4d1ba4331a6..c25ef5cbe651c 100644
--- a/sklearn/feature_selection/tests/test_from_model.py
+++ b/sklearn/feature_selection/tests/test_from_model.py
@@ -183,3 +183,7 @@ def test_threshold_without_refitting():
     # Set a higher threshold to filter out more features.
     model.threshold = "1.0 * mean"
     assert_greater(X_transform.shape[1], model.transform(data).shape[1])
+
+
+def test_frozen():
+    raise NotImplementedError()
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index a47c5f48f2fe2..e13d30dc01a72 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -14,7 +14,7 @@
 import numpy as np
 from scipy import sparse
 
-from .base import clone, TransformerMixin
+from .base import clone, TransformerMixin, frozen_fit
 from .externals.joblib import Parallel, delayed, Memory
 from .externals import six
 from .utils import tosequence
@@ -52,6 +52,8 @@ class Pipeline(_BaseComposition):
     chained, in the order in which they are chained, with the last object
     an estimator.
 
+    Some of these estimators may be frozen (see :ref:`frozen`).
+
     memory : Instance of sklearn.external.joblib.Memory or string, optional \
         (default=None)
         Used to cache the fitted transformers of the pipeline. By default,
@@ -256,7 +258,7 @@ def fit(self, X, y=None, **fit_params):
         """
         Xt, fit_params = self._fit(X, y, **fit_params)
         if self._final_estimator is not None:
-            self._final_estimator.fit(Xt, y, **fit_params)
+            frozen_fit(self._final_estimator, 'fit', Xt, y, **fit_params)
         return self
 
     def fit_transform(self, X, y=None, **fit_params):
@@ -289,11 +291,12 @@ def fit_transform(self, X, y=None, **fit_params):
         last_step = self._final_estimator
         Xt, fit_params = self._fit(X, y, **fit_params)
         if hasattr(last_step, 'fit_transform'):
-            return last_step.fit_transform(Xt, y, **fit_params)
+            return frozen_fit(last_step, 'fit_transform', Xt, y, **fit_params)
         elif last_step is None:
             return Xt
         else:
-            return last_step.fit(Xt, y, **fit_params).transform(Xt)
+            return frozen_fit(last_step, 'fit', Xt, y,
+                              **fit_params).transform(Xt)
 
     @if_delegate_has_method(delegate='_final_estimator')
     def predict(self, X):
@@ -536,7 +539,8 @@ def make_pipeline(*steps, **kwargs):
 
     Parameters
     ----------
-    *steps : list of estimators,
+    *steps : list of estimators
+        Some of these estimators may be frozen (see :ref:`frozen`).
     memory : Instance of sklearn.externals.joblib.Memory or string, optional \
             (default=None)
 
@@ -572,7 +576,7 @@
 
 def _fit_one_transformer(transformer, X, y):
-    return transformer.fit(X, y)
+    return frozen_fit(transformer, 'fit', X, y)
 
 
 def _transform_one(transformer, weight, X):
@@ -586,9 +590,9 @@
 def _fit_transform_one(transformer, weight, X, y, **fit_params):
     if hasattr(transformer, 'fit_transform'):
-        res = transformer.fit_transform(X, y, **fit_params)
+        res = frozen_fit(transformer, 'fit_transform', X, y, **fit_params)
     else:
-        res = transformer.fit(X, y, **fit_params).transform(X)
+        res = frozen_fit(transformer, 'fit', X, y, **fit_params).transform(X)
     # if we have a weight for this transformer, multiply output
     if weight is None:
         return res, transformer
@@ -615,6 +619,8 @@ class FeatureUnion(_BaseComposition, TransformerMixin):
         List of transformer objects to be applied to the data. The first
         half of each tuple is the name of the transformer.
 
+        Some of these transformers may be frozen (see :ref:`frozen`).
+
     n_jobs : int, optional
         Number of jobs to run in parallel (default 1).
@@ -800,6 +806,7 @@ def make_union(*transformers, **kwargs):
     Parameters
     ----------
     *transformers : list of estimators
+        Some of these transformers may be frozen (see :ref:`frozen`).
     n_jobs : int, optional
         Number of jobs to run in parallel (default 1).
diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py
index 948d5818b9b0e..7e93ca280104c 100644
--- a/sklearn/tests/test_base.py
+++ b/sklearn/tests/test_base.py
@@ -26,7 +26,7 @@
 from sklearn import datasets
 from sklearn.utils import deprecated
-from sklearn.base import TransformerMixin
+from sklearn.base import TransformerMixin, frozen_fit
 from sklearn.utils.mocking import MockDataFrame
 import pickle
@@ -188,6 +188,32 @@ def test_clone_sparse_matrices():
     assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())
 
 
+def test_clone_frozen():
+    est = DecisionTreeClassifier()
+    est.fit([[0], [1]], [0, 1])
+    assert clone(est) is not est
+    est.frozen = True
+    assert clone(est) is est
+    # is still fitted
+    assert est.predict([[0]]) == 0
+    # freezing works recursively
+    seq = [est]
+    assert clone(seq) is not seq
+    assert clone(seq)[0] is est
+
+    # freezing first then fitting works too
+    est = DecisionTreeClassifier()
+    est.frozen = True
+    est.fit([[0], [1]], [0, 1])
+    assert clone(est) is est
+    assert est.predict([[0]]) == 0
+
+    # can freeze an unfitted estimator (not sure why, but worth testing)
+    est = DecisionTreeClassifier()
+    est.frozen = True
+    assert clone(est) is est
+
+
 def test_repr():
     # Smoke test the repr of the base estimator.
     my_estimator = MyEstimator()
@@ -450,3 +476,75 @@ def test_pickling_works_when_getstate_is_overwritten_in_the_child_class():
     estimator_restored = pickle.loads(serialized)
     assert_equal(estimator_restored.attribute_pickled, 5)
     assert_equal(estimator_restored._attribute_not_pickled, None)
+
+
+def test_frozen_fit():
+    class DummyEstimator(BaseEstimator):
+        def fit(self, X, y, **kwargs):
+            self.X_ = X
+            self.y_ = y
+            self.kwargs_ = kwargs
+            self._last_call = 'fit'
+            return self
+
+        def fit_transform(self, X, y, **kwargs):
+            self.X_ = X
+            self.y_ = y
+            self.kwargs_ = kwargs
+            self._last_call = 'fit_transform'
+            return np.array(self.X_[0]) + X
+
+        def transform(self, X):
+            self._last_call = 'transform'
+            return np.array(self.X_[0]) + X
+
+        def fit_wobble(self, X, y, **kwargs):
+            self.X_ = X
+            self.y_ = y
+            self.kwargs_ = kwargs
+            self._last_call = 'fit_wobble'
+            return self.y_[0] + X[0][0]
+
+        def wobble(self, X):
+            self._last_call = 'wobble'
+            return self.y_[0] + X[0][0]
+
+    X_freeze = [[5]]
+    y_freeze = [-1]
+    z_freeze = [0]
+    X_train = [[10]]
+    y_train = [1]
+    z_train = [0]
+
+    for fit_method, method in [('fit', None),
+                               ('fit_transform', 'transform'),
+                               ('fit_wobble', 'wobble')]:
+
+        # est is not frozen
+        est = DummyEstimator().fit(X_freeze, y_freeze, z=z_freeze)
+
+        result = frozen_fit(est, fit_method, X_train, y_train, z=z_train)
+        # check it called .fit_transform(), not .fit().transform(), for example
+        assert est._last_call == fit_method
+        # check model was re-fit
+        assert est.X_ == X_train
+        assert est.y_ == y_train
+        assert est.kwargs_ == {'z': z_train}
+        if fit_method == 'fit':
+            assert result is est
+        else:
+            assert_array_equal(result, getattr(est, method)(X_train))
+
+        # est is frozen
+        est = DummyEstimator().fit(X_freeze, y_freeze, z=z_freeze)
+        est.frozen = True
+        result = frozen_fit(est, fit_method, X_train, y_train, z=z_train)
+        # check model was not re-fit
+        assert est.X_ == X_freeze
+        assert est.y_ == y_freeze
+        assert est.kwargs_ == {'z': z_freeze}
+        if fit_method == 'fit':
+            assert result is est
+        else:
+            assert est._last_call == method
+            assert_array_equal(result, getattr(est, method)(X_train))
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index 2549d84dfcea5..38e51c47ea5dc 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -928,3 +928,11 @@ def test_make_pipeline_memory():
     assert_true(pipeline.memory is None)
 
     shutil.rmtree(cachedir)
+
+
+def test_pipeline_frozen():
+    raise NotImplementedError()
+
+
+def test_feature_union_frozen():
+    raise NotImplementedError()
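
The ``test_pipeline_frozen`` and ``test_feature_union_frozen`` stubs above still
need bodies. Below is a minimal sketch of what they might assert, under the
semantics proposed in this patch (a frozen step keeps its fitted state, and
``clone`` returns it unchanged). The choice of ``StandardScaler`` and
``LogisticRegression``, and all of the data, are illustrative only::

    # Sketch only: assumes the frozen-estimator semantics introduced in this patch.
    import numpy as np

    from sklearn.base import clone
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline, make_union
    from sklearn.preprocessing import StandardScaler


    def test_pipeline_frozen():
        X_pre = np.array([[0.], [10.]])  # stands in for large pre-training data
        X, y = np.array([[1.], [2.]]), np.array([0, 1])

        scaler = StandardScaler().fit(X_pre)
        scaler.frozen = True
        pipe = make_pipeline(scaler, LogisticRegression()).fit(X, y)

        # the frozen step kept the statistics learnt from X_pre, not X
        assert np.allclose(pipe.named_steps['standardscaler'].mean_,
                           X_pre.mean(axis=0))
        # cloning the pipeline preserves the identity of the frozen step
        assert clone(pipe).named_steps['standardscaler'] is scaler


    def test_feature_union_frozen():
        X = np.array([[1.], [2.]])

        frozen = StandardScaler().fit(np.array([[0.], [4.]]))  # mean_ == [2.]
        frozen.frozen = True
        union = make_union(frozen, StandardScaler())
        Xt = union.fit_transform(X)

        # first column: the frozen branch scaled X with its pre-fit statistics
        assert np.allclose(Xt[:, 0], (X[:, 0] - 2.) / 2.)
        # and was not re-fit by fit_transform
        assert np.allclose(frozen.mean_, [2.])

The stubbed ``VotingClassifier`` and ``SelectFromModel`` tests could follow the
same pattern: pre-fit a component, set ``frozen = True``, then check that
fitting the meta-estimator leaves the frozen component's state untouched.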