-
-
Notifications
You must be signed in to change notification settings - Fork 26k
FIX Enable setting of sub-parameters for deprecated base_estimator param #25477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
698ddac
3d6d13c
2b13d8b
88407d5
a331227
89ee48f
8be004e
5c4b441
d20c386
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -227,6 +227,25 @@ def set_params(self, **params): | |
valid_params[key] = value | ||
|
||
for key, sub_params in nested_params.items(): | ||
# TODO(1.4): remove specific handling of "base_estimator". | ||
# The "base_estimator" key is special. It was deprecated and | ||
# renamed to "estimator" for several estimators. This means we | ||
# need to translate it here and set sub-parameters on "estimator", | ||
# but only if the user did not explicitly set a value for | ||
# "base_estimator". | ||
if ( | ||
key == "base_estimator" | ||
and valid_params[key] == "deprecated" | ||
and self.__module__.startswith("sklearn.") | ||
): | ||
warnings.warn( | ||
f"Parameter 'base_estimator' of {self.__class__.__name__} is" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not actionable: It would be harder for third party estimators to deprecate a positional parameter because they would also need adjust |
||
" deprecated in favor of 'estimator'. See" | ||
f" {self.__class__.__name__}'s docstring for more details.", | ||
FutureWarning, | ||
stacklevel=2, | ||
) | ||
key = "estimator" | ||
valid_params[key].set_params(**sub_params) | ||
|
||
return self | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -630,3 +630,21 @@ def test_base_estimator_property_deprecated(AdaBoost): | |
) | ||
with pytest.warns(FutureWarning, match=warn_msg): | ||
model.base_estimator_ | ||
|
||
|
||
# TODO(1.4): remove in 1.4 | ||
def test_deprecated_base_estimator_parameters_can_be_set(): | ||
"""Check that setting base_estimator parameters works. | ||
|
||
During the deprecation cycle setting "base_estimator__*" params should | ||
work. | ||
|
||
Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/25470 | ||
""" | ||
# This implicitly sets "estimator", it is how old code (pre v1.2) would | ||
# have instantiated AdaBoostClassifier and back then it would set | ||
# "base_estimator". | ||
clf = AdaBoostClassifier(DecisionTreeClassifier()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually here we are setting It could also be work to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this test is enough. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is indeed what happens. I also think the warning at |
||
|
||
with pytest.warns(FutureWarning, match="Parameter 'base_estimator' of"): | ||
clf.set_params(base_estimator__max_depth=2) |
Uh oh!
There was an error while loading. Please reload this page.