ENH estimator freezing #9397
Changes from all commits
sklearn/base.py

@@ -44,6 +44,8 @@ def clone(estimator, safe=True):
     """
     estimator_type = type(estimator)
+    if getattr(estimator, 'frozen', False):
+        return estimator
     # XXX: not handling dictionaries
     if estimator_type in (list, tuple, set, frozenset):
         return estimator_type([clone(e, safe=safe) for e in estimator])
@@ -578,3 +580,39 @@ def is_regressor(estimator):
         True if estimator is a regressor and False otherwise.
     """
     return getattr(estimator, "_estimator_type", None) == "regressor"


+def frozen_fit(estimator, method, X, y, **kwargs):
+    """Fit the estimator if not frozen, and return the result of method
+
+    A frozen estimator has an attribute ``frozen`` set to True.
Review comment: It took me a bit to understand what this function is doing. Maybe expand a bit on the explanation? I'm not sure I have actually understood it, I think...

Review comment: How about naming this "fit_if_not_frozen"? Maybe that would help @amueller (and others) understand.

Review comment: I think a better way to make it easier to understand is to remove […]
+
+    Parameters
+    ----------
+    estimator
+    method : str
+        One of {'fit', 'fit_transform', 'fit_predict'} or similar.
+    X
+    y
+        will only be passed when fitting
+    kwargs
+        will only be passed when fitting
+
+    Returns
+    -------
+    out
+        estimator if ``method == 'fit'``, else the output of ``transform``
+        etc. If the estimator has attribute ``frozen`` set to True, it will
+        not be refit.
Review comment: This would benefit from an example.
+    """
+    if getattr(estimator, 'frozen', False):
+        if method == 'fit':
+            return estimator
Review comment: Ideally we should check it was fit before, but I guess that's hard and we defer to when it's used?
+
+        if not method.startswith('fit_'):
+            raise ValueError('method must be "fit" or begin with "fit_"')
+        method = getattr(estimator, method[4:])
Review comment: Wow, this is a bit of a hack lol. And a model doesn't necessarily have that, right? If I call […]

Review comment: Isn't an attribute error what you should be getting, though?

Review comment: is it? The call is […]
+        # FIXME: what do we do with kwargs?
Review comment: I would say: pass them along, to minimize surprise: that way, if people have coded an estimator that takes extra arguments in fit_*, they get what they expect. I guess that's an argument for sample_props, and then actively battling arguments in fit_*.
+        return method(X)
+    else:
+        method = getattr(estimator, method)
+        return method(X, y, **kwargs)
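As requested in the review above, a usage sketch may help. This is an illustration only, assuming the patched ``sklearn.base`` exports ``frozen_fit`` and that freezing is signalled by the ``frozen`` attribute:

```python
import numpy as np
from sklearn.base import frozen_fit
from sklearn.preprocessing import StandardScaler

X_train = np.array([[0.], [1.], [2.]])
X_new = np.array([[3.], [4.]])

scaler = StandardScaler().fit(X_train)  # fitted once, elsewhere
scaler.frozen = True                    # freeze: later fits become no-ops

# Frozen + 'fit_transform': no refit; falls back to scaler.transform(X_new).
Xt = frozen_fit(scaler, 'fit_transform', X_new, None)

# Frozen + 'fit': the estimator is returned untouched.
assert frozen_fit(scaler, 'fit', X_new, None) is scaler
```

An unfrozen estimator takes the other branch and is simply refit via ``getattr(estimator, method)(X, y, **kwargs)``.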
sklearn/ensemble/tests/test_voting_classifier.py

@@ -367,6 +367,9 @@ def test_estimator_weights_format():
     assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))


+def test_frozen():
+    raise NotImplementedError()
Review comment: ?? Is this a TODO for you to address?
 def test_transform():
     """Check transform method of VotingClassifier on toy dataset."""
     clf1 = LogisticRegression(random_state=123)
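Responding to the review question above: the stub is clearly a placeholder. One hypothetical shape it could take (not from the PR; ``clone`` would need importing in this module) is:

```python
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression


def test_frozen():
    # Hypothetical: a frozen, fitted estimator should survive clone()
    # untouched instead of being reset to an unfitted copy.
    X_toy, y_toy = [[0.], [1.]], [0, 1]
    clf = LogisticRegression(random_state=123).fit(X_toy, y_toy)
    clf.frozen = True
    assert clone(clf) is clf
```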
sklearn/pipeline.py

@@ -14,7 +14,7 @@
 import numpy as np
 from scipy import sparse

-from .base import clone, TransformerMixin
+from .base import clone, TransformerMixin, frozen_fit
 from .externals.joblib import Parallel, delayed, Memory
 from .externals import six
 from .utils import tosequence
@@ -52,6 +52,8 @@ class Pipeline(_BaseComposition):
         chained, in the order in which they are chained, with the last object
         an estimator.

+        Some of these estimators may be frozen (see :ref:`frozen`).
+
     memory : Instance of sklearn.external.joblib.Memory or string, optional \
             (default=None)
         Used to cache the fitted transformers of the pipeline. By default,
@@ -256,7 +258,7 @@ def fit(self, X, y=None, **fit_params):
         """
         Xt, fit_params = self._fit(X, y, **fit_params)
         if self._final_estimator is not None:
-            self._final_estimator.fit(Xt, y, **fit_params)
+            frozen_fit(self._final_estimator, 'fit', Xt, y, **fit_params)
         return self

     def fit_transform(self, X, y=None, **fit_params):
@@ -289,11 +291,12 @@ def fit_transform(self, X, y=None, **fit_params):
         last_step = self._final_estimator
         Xt, fit_params = self._fit(X, y, **fit_params)
         if hasattr(last_step, 'fit_transform'):
-            return last_step.fit_transform(Xt, y, **fit_params)
+            return frozen_fit(last_step, 'fit_transform', Xt, y, **fit_params)
         elif last_step is None:
             return Xt
         else:
-            return last_step.fit(Xt, y, **fit_params).transform(Xt)
+            return frozen_fit(last_step, 'fit', Xt, y,
+                              **fit_params).transform(Xt)

     @if_delegate_has_method(delegate='_final_estimator')
     def predict(self, X):
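To see what these changes enable: a sketch of the intended workflow, assuming the patched ``Pipeline`` above and the ``frozen`` attribute convention (illustration only, not code from the PR):

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

X_old, y_old = np.array([[0.], [1.], [2.], [3.]]), np.array([0, 0, 1, 1])
X_new, y_new = np.array([[1.], [2.], [3.], [4.]]), np.array([0, 1, 1, 1])

clf = LogisticRegression().fit(X_old, y_old)  # pre-fitted elsewhere
clf.frozen = True                             # freeze the fitted model

pipe = Pipeline([('scale', StandardScaler()), ('clf', clf)])
# _fit refits the unfrozen scaler; frozen_fit leaves clf untouched.
pipe.fit(X_new, y_new)
```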
@@ -536,7 +539,8 @@ def make_pipeline(*steps, **kwargs):

     Parameters
     ----------
-    *steps : list of estimators,
+    *steps : list of estimators
+        Some of these estimators may be frozen (see :ref:`frozen`).

     memory : Instance of sklearn.externals.joblib.Memory or string, optional \
             (default=None)
@@ -572,7 +576,7 @@


 def _fit_one_transformer(transformer, X, y):
-    return transformer.fit(X, y)
+    return frozen_fit(transformer, 'fit', X, y)


 def _transform_one(transformer, weight, X):
@@ -586,9 +590,9 @@ def _transform_one(transformer, weight, X):
 def _fit_transform_one(transformer, weight, X, y,
                        **fit_params):
     if hasattr(transformer, 'fit_transform'):
-        res = transformer.fit_transform(X, y, **fit_params)

Review comment: If we'd add a […]

+        res = frozen_fit(transformer, 'fit_transform', X, y, **fit_params)
     else:
-        res = transformer.fit(X, y, **fit_params).transform(X)
+        res = frozen_fit(transformer, 'fit', X, y, **fit_params).transform(X)
     # if we have a weight for this transformer, multiply output
     if weight is None:
         return res, transformer
@@ -615,6 +619,8 @@ class FeatureUnion(_BaseComposition, TransformerMixin):
         List of transformer objects to be applied to the data. The first
         half of each tuple is the name of the transformer.

+        Some of these transformers may be frozen (see :ref:`frozen`).
+
     n_jobs : int, optional
         Number of jobs to run in parallel (default 1).
@@ -800,6 +806,7 @@ def make_union(*transformers, **kwargs):
     Parameters
     ----------
     *transformers : list of estimators
+        Some of these transformers may be frozen (see :ref:`frozen`).

     n_jobs : int, optional
         Number of jobs to run in parallel (default 1).
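Similarly for ``FeatureUnion``: a hedged sketch, assuming the patched helpers above, of a frozen transformer contributing its previously learnt transform without being refit:

```python
import numpy as np
from sklearn.pipeline import make_union
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

X_old = np.array([[0., 1.], [1., 0.], [2., 2.]])
X_new = np.array([[1., 1.], [2., 0.], [0., 2.]])

pca = PCA(n_components=1).fit(X_old)  # fitted elsewhere
pca.frozen = True

union = make_union(StandardScaler(), pca)
# The scaler is refit on X_new; the frozen PCA is only applied.
Xt = union.fit_transform(X_new)       # shape (3, 3): 2 scaled cols + 1 PC
```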
Review comment: do we want to add BaseEstimator.freeze? It saves like 5 characters ;) But it also allows us to change the implementation details.

Review comment: I'd rather not: having a "freeze()" method rather than a "frozen" attribute means that the logic is modifiable in subclasses: the contract is more open: "freeze()" could change more about the estimator. This means that it is potentially harder to understand.

Review comment: I don't think a method is helpful either. It only makes it harder to use something that, for whatever strange reason, does not inherit from the base class.

Review comment: @jnothman hm, that's a fair point. I'm just concerned that this will be very hard to change in the future, if we ever decide that's necessary.