-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+1] Accept keyword parameters to hyperparameter search fit methods #8278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG+1] Accept keyword parameters to hyperparameter search fit methods #8278
Conversation
Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``.
Thanks Stephen. I'm low on time at the moment, and will work through my
backlog eventually!
…On 4 Feb 2017 1:19 am, "Stephen Hoover" ***@***.***> wrote:
@jnothman <https://github.com/jnothman> , this is my minimal fix to #2879
<#2879> .
—
You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub
<#8278 (comment)>,
or mute the thread
<https://github.com/notifications/unsubscribe-auth/AAEz6yG7gsd6r1ZiBwl8VixrUInjN7Ajks5rYzdegaJpZM4L10eQ>
.
|
sklearn/utils/mocking.py
Outdated
@@ -45,19 +45,30 @@ class CheckingClassifier(BaseEstimator, ClassifierMixin): | |||
changed the input. | |||
""" | |||
def __init__(self, check_y=None, | |||
check_X=None, foo_param=0): | |||
check_X=None, foo_param=0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cosmetic: put more args on the first line as long as they fit under 80 columns
sklearn/utils/mocking.py
Outdated
assert_true(len(X) == len(y)) | ||
if self.check_X is not None: | ||
assert_true(self.check_X(X)) | ||
if self.check_y is not None: | ||
assert_true(self.check_y(y)) | ||
self.classes_ = np.unique(check_array(y, ensure_2d=False, | ||
allow_nd=True)) | ||
if self.expected_fit_params: | ||
missing = (set(np.atleast_1d(self.expected_fit_params)) - |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why converting to an numpy array with np.atleast_1d
before converting to a set? Converting the list to a set is enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The np.atleast_1d
call converts a string to a one-element array, so that users could give a string input to expected_fit_params
instead of a one-element list.
In this case, without the array casting there would be an obvious and loud error if someone forgot to make that input a list, and there's not much benefit to saving two characters for a test helper. I'll remove the np.atleast_1d
.
Please add a new test to check that the deprecation warning is raised as expected. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also remove the fit_params parameter from the class docstrings and do a review of the full code base with git grep fit_params=
to check that fit_params are no longer passed as input to the parameter search class constructors and update the code to pass those to the fit method instead.
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) | ||
|
||
# The CheckingClassifer generates an assertion error if | ||
# a parameter is missing or has length != len(X). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of putting a comment to explain what CheckingClassifier should raise in that case, please add an assertion with assert_raises
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider assert_raises_regex
to check the error message on top of checking the exception type.
random_search = RandomizedSearchCV(clf, {'foo_param': [0]}, n_iter=1) | ||
|
||
# The CheckingClassifer generates an assertion error if | ||
# a parameter is missing or has length != len(X). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add an assertion with assert_raises
here.
LGTM besides the requested changes. |
Codecov Report
@@ Coverage Diff @@
## master #8278 +/- ##
=========================================
Coverage ? 94.74%
=========================================
Files ? 342
Lines ? 60711
Branches ? 0
=========================================
Hits ? 57519
Misses ? 3192
Partials ? 0
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you very much, LGTM.
I am +1 to merge. This will be modified by the sample_props, but is it indeed going in the right direction in terms of API. |
scikit-learn#8278) * ENH Accept keyword parameters to hyperparameter search fit methods Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``. * CR: Expanded tests, remove deprecated use in Ridge * Make tests consistent in Python 2 and 3
Does |
@ecampana , yes, |
@stephen-hoover, thank you for your reply and confirming that |
scikit-learn#8278) * ENH Accept keyword parameters to hyperparameter search fit methods Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``. * CR: Expanded tests, remove deprecated use in Ridge * Make tests consistent in Python 2 and 3
scikit-learn#8278) * ENH Accept keyword parameters to hyperparameter search fit methods Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``. * CR: Expanded tests, remove deprecated use in Ridge * Make tests consistent in Python 2 and 3
scikit-learn#8278) * ENH Accept keyword parameters to hyperparameter search fit methods Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``. * CR: Expanded tests, remove deprecated use in Ridge * Make tests consistent in Python 2 and 3
scikit-learn#8278) * ENH Accept keyword parameters to hyperparameter search fit methods Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``. * CR: Expanded tests, remove deprecated use in Ridge * Make tests consistent in Python 2 and 3
scikit-learn#8278) * ENH Accept keyword parameters to hyperparameter search fit methods Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``. * CR: Expanded tests, remove deprecated use in Ridge * Make tests consistent in Python 2 and 3
Reference Issue
Closes #2879
What does this implement/fix? Explain your changes.
This PR allows the hyperparameter search classes in
sklearn.model_selection._search
(i.e.GridSearchCV
andRandomizedSearchCV
) to accept keyword parameters and pass them to the wrapped Estimator. This brings theirfit
methods into closer compliance with the Estimator API. This PR does not address issues #8158 or #4632 which deal with passing parameters to the scoring function.To ease testing of the new parameter, I added an extra argument to the
sklearn.utils.mocking.CheckingClassifier
which allows it to enforce the presence of keywordfit
arguments and verify that they have the expected length.Any other comments?
The ultimate fate of fit parameters on these methods will depend on the resolution to #4497, but I think the change in this PR conforms to present usage without establishing any new conventions. Accepting keyword parameters in these
fit
methods is something that I (and I think other users) would expect.One limitation of this PR is that the
groups
parameter is still reserved by thefit
method and not passed to the wrapped Estimator. Perhaps a solution to that would wait on #4497, but I think it doesn't need to be solved in this PR.I also had to choose what to do if users try to provide fit parameters through both the constructor and the
fit
method. I chose to emit a warning and ignore parameters passed through the constructor (since that behavior is deprecated), but I don't feel strongly about that decision.I searched through existing documentation and tests, but I was unable to find any uses of the
fit_params
constructor argument for eitherGridSearchCV
orRandomizedSearchCV
.