[MRG] FIX bug in nested set_params usage #9999
Fixes the issue where an estimator is replaced and one of its parameters is set in the same set_params call: scikit-learn#9945 (comment)
nitpicks but otherwise +1
sklearn/tests/test_base.py
Outdated
def test_set_params_updates_valid_params():
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    pipe = GridSearchCV(DecisionTreeClassifier(), {})
Why name it pipe instead of something like gridsearchcv?
sklearn/tests/test_base.py
Outdated
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    pipe = GridSearchCV(DecisionTreeClassifier(), {})
    pipe.set_params(estimator=SVC(), estimator__C=1.0)
Could you please extend this test to set estimator__C=42.0 and then assert gridsearchcv.estimator.C == 42.0?
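A sketch of what the extended test might look like, assuming the variable is renamed to gridsearchcv as the earlier comment suggests. The import paths are assumptions (the thread does not show the test module's imports), and this is an illustration of the request rather than the test that was actually merged. It only passes on a scikit-learn version that includes this fix.

```python
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier


def test_set_params_updates_valid_params():
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    gridsearchcv = GridSearchCV(DecisionTreeClassifier(), {})
    # Replace the estimator and set one of the new estimator's
    # parameters in the same call.
    gridsearchcv.set_params(estimator=SVC(), estimator__C=42.0)
    assert gridsearchcv.estimator.C == 42.0


test_set_params_updates_valid_params()
```

Using a value other than SVC's default C makes the assertion meaningful: it can only hold if the nested parameter really reached the new estimator.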
sklearn/tests/test_base.py
Outdated
def test_set_params_updates_valid_params():
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    pipe = GridSearchCV(DecisionTreeClassifier(), {})
pipe may not be the best of names ;-)
Sorry, I started with a Pipeline then realised it's not affected!
On 25 October 2017 at 20:41, Loïc Estève wrote:
> In sklearn/tests/test_base.py:
> pipe may not be the best of names ;-)
|
Thanks for the quick reviews!
|
In the SO question, the OP mentions that this diff is working for him:

diff --git a/sklearn/base.py b/sklearn/base.py
index b653b7149..81c7e5dae 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -263,6 +263,7 @@ class BaseEstimator(object):
                 nested_params[key][sub_key] = value
             else:
                 setattr(self, key, value)
+                valid_params[key] = value
         for key, sub_params in nested_params.items():
             valid_params[key].set_params(**sub_params)

I checked indeed that the test passes with this patch, which looks simpler than your current change. You know a lot more than me about this code, so there may be a reason why your current change is the way it is. Maybe there is an edge case that the simpler patch does not cover; if that is the case, it would be great to add a test for it.
|
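To see why the one-line patch is enough, here is a minimal pure-Python sketch of the logic the diff touches. MiniEstimator is a hypothetical stand-in written for this illustration, not scikit-learn's BaseEstimator. The key point is that valid_params is a snapshot taken before any attribute changes, so without the added line the nested parameters would be sent to the old sub-estimator.

```python
from collections import defaultdict


class MiniEstimator:
    """Hypothetical stand-in for BaseEstimator (illustration only)."""

    def __init__(self, **params):
        self._param_names = list(params)
        for name, value in params.items():
            setattr(self, name, value)

    def get_params(self):
        return {name: getattr(self, name) for name in self._param_names}

    def set_params(self, **params):
        # Snapshot of the parameters taken before anything is changed.
        valid_params = self.get_params()
        nested_params = defaultdict(dict)
        for key, value in params.items():
            key, delim, sub_key = key.partition("__")
            if key not in valid_params:
                raise ValueError(f"Invalid parameter {key!r}")
            if delim:
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                # The line added by the patch: keep the snapshot in sync so
                # nested parameters go to the new object, not the old one.
                valid_params[key] = value
        for key, sub_params in nested_params.items():
            valid_params[key].set_params(**sub_params)
        return self


old_inner = MiniEstimator(max_depth=3)
new_inner = MiniEstimator(C=1.0)
outer = MiniEstimator(estimator=old_inner)

# Replace the inner estimator and set one of its parameters in one call.
outer.set_params(estimator=new_inner, estimator__C=42.0)
```

If the marked line is removed from this sketch, the final call tries to set C on old_inner and raises ValueError, which mirrors the behaviour reported in the linked issue.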
Sorry, I didn't see that on my phone browser. Yes, it's a better patch. Will
fix. Why didn't the OP make a PR??
On 25 Oct 2017 8:48 pm, "Loïc Estève" wrote:
> In the SO question, the OP mentions that this diff is working for him: [...]
|
Hey, the OP of SO here. I could PR it, but as I am not very experienced with such open source work I was kinda intimidated by simply adding the PR. As @lesteve mentions, who knows what edge cases I may be breaking. |
I should probably have let you post a PR. But yes, contributors tend to be
scared of this part of the codebase...
Please do contribute in the future. Your SO post is the sort of stuff
dreams are made of in the FOSS world; if only it were posted as a PR!
I'd like to think that edge cases are tested cases.... :\
|
LGTM as well, thanks @marcus-voss for the clean fix! |
Thanks all! |