ENH allow extra params to be copied in clone #20681

Closed · wants to merge 5 commits
5 changes: 5 additions & 0 deletions doc/developers/develop.rst
@@ -618,6 +618,11 @@ X_types (default=['2darray'])
``'categorical'`` data. For now, the tests for sparse data do not make use
of the ``'sparse'`` tag.

non_init_params (default=[])
A list of names of parameters which need to be copied by ``clone`` but are
not ``__init__`` parameters and are therefore not included in the output of
``get_params``.
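To make the proposed semantics of this tag concrete, here is a minimal, self-contained sketch in plain Python that mirrors the ``clone`` logic added in this PR rather than importing scikit-learn. ``MiniEstimator``, ``mini_clone``, ``WithExtra``, and ``param1`` are illustrative names, not actual scikit-learn API:

```python
import copy

_DEFAULT_TAGS = {"non_init_params": []}  # abridged default tags


class MiniEstimator:
    # Stand-in for BaseEstimator: params are __init__ arguments only.
    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self):
        return {"alpha": self.alpha}

    def _get_tags(self):
        tags = dict(_DEFAULT_TAGS)
        tags.update(self._more_tags())
        return tags

    def _more_tags(self):
        return {}


def mini_clone(estimator):
    # Rebuild from constructor parameters, as sklearn.base.clone does...
    new = type(estimator)(**estimator.get_params())
    # ...then additionally copy anything listed in the "non_init_params" tag.
    for name in estimator._get_tags()["non_init_params"]:
        try:
            setattr(new, name, copy.deepcopy(getattr(estimator, name)))
        except AttributeError:
            pass  # tolerate parameters that were never set on the original
    return new


class WithExtra(MiniEstimator):
    def _more_tags(self):
        return {"non_init_params": ["param1"]}


est = WithExtra(alpha=0.5)
est.param1 = 42
cloned = mini_clone(est)
print(cloned.alpha, cloned.param1)  # 0.5 42
```

As in the PR, a name that is listed in the tag but never set on the original estimator is silently skipped rather than raising.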

It is unlikely that the default values for each tag will suit the needs of your
specific estimator. Additional tags can be created or default tags can be
overridden by defining a `_more_tags()` method which returns a dict with the
4 changes: 4 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -137,6 +137,10 @@ Changelog
:mod:`sklearn.base`
...................

- |Enhancement| :func:`clone` now copies the extra parameters listed in
  ``tags["non_init_params"]``, and ``check_no_attributes_set_in_init``
  tolerates those attributes being set in ``__init__``. :pr:`20681` by
  `Adrin Jalali`_.

- |Fix| :func:`config_context` is now threadsafe. :pr:`18736` by `Thomas Fan`_.

:mod:`sklearn.calibration`
14 changes: 14 additions & 0 deletions sklearn/base.py
@@ -33,6 +33,10 @@ def clone(estimator, *, safe=True):
without actually copying attached data. It yields a new estimator
with the same parameters that has not been fitted on any data.

If there are parameters which need to be copied but are not returned by
``get_params``, list their names in the ``non_init_params`` estimator tag
as an iterable of parameter names.

If the estimator's `random_state` parameter is an integer (or if the
estimator doesn't have a `random_state` parameter), an *exact clone* is
returned: the clone and the original estimator will give the exact same
@@ -80,6 +84,16 @@
new_object = klass(**new_object_params)
params_set = new_object.get_params(deep=False)

# Then copy the parameters listed in the "non_init_params" tag; these are
# not returned by ``get_params``, so the constructor call above did not
# set them on the clone.
extra_params = _safe_tags(estimator).get("non_init_params", [])
for param in extra_params:
try:
setattr(new_object, param, clone(getattr(estimator, param), safe=False))
except AttributeError:
# silently skip parameters which are not set on the original
pass

# quick sanity check of the parameters of the clone
for name in new_object_params:
param1 = new_object_params[name]
15 changes: 15 additions & 0 deletions sklearn/tests/test_base.py
@@ -141,6 +141,21 @@ def test_clone_2():
assert not hasattr(new_selector, "own_attribute")


def test_clone_extra_params():
# Test that clone copies the parameters provided in tags["non_init_params"]
class Estimator(BaseEstimator):
def _more_tags(self):
return {"non_init_params": ["param1"]}

est = Estimator()
# test that clone works when "param1" is not present
est_copy = clone(est)
assert not hasattr(est_copy, "param1")
est.param1 = 42
est_copy = clone(est)
assert est_copy.param1 == 42


def test_clone_buggy():
# Check that clone raises an error on buggy estimators.
buggy = Buggy()
1 change: 1 addition & 0 deletions sklearn/utils/_tags.py
@@ -19,6 +19,7 @@
"preserves_dtype": [np.float64],
"requires_y": False,
"pairwise": False,
"non_init_params": [],
}
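To see how a per-estimator override lands on top of these defaults, here is a rough standalone sketch of the merge that ``_safe_tags`` performs. It is deliberately simplified (the real helper also walks the MRO and handles non-BaseEstimator instances); ``safe_tags`` and ``Est`` are illustrative names:

```python
_DEFAULT_TAGS = {"non_init_params": [], "requires_y": False}  # abridged


def safe_tags(estimator):
    # Start from the defaults, then layer the estimator's _more_tags() on top.
    tags = dict(_DEFAULT_TAGS)
    more_tags = getattr(estimator, "_more_tags", None)
    if more_tags is not None:
        tags.update(more_tags())
    return tags


class Est:
    def _more_tags(self):
        return {"non_init_params": ["param1"]}


print(safe_tags(Est())["non_init_params"])      # ['param1']
print(safe_tags(object())["non_init_params"])   # [] (falls back to default)
```

Because the default is ``[]``, code such as ``clone`` can always iterate the tag without checking whether an estimator declared it.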


9 changes: 7 additions & 2 deletions sklearn/utils/estimator_checks.py
@@ -2750,8 +2750,13 @@ def check_no_attributes_set_in_init(name, estimator_orig):
for param in params_parent
]

# Test for no setting apart from parameters during init
invalid_attr = set(vars(estimator)) - set(init_params) - set(parents_init_params)
# Test for no setting apart from parameters and "non_init_params" during init
invalid_attr = (
set(vars(estimator))
- set(init_params)
- set(parents_init_params)
- set(_safe_tags(estimator).get("non_init_params", []))
)
assert not invalid_attr, (
"Estimator %s should not set any attribute apart"
" from parameters during init. Found attributes %s."
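The adjusted invariant can be sketched standalone: attributes present on a freshly constructed estimator must either be ``__init__`` parameters or appear in the ``non_init_params`` tag. ``check_init_attrs`` and the toy classes below are illustrative stand-ins, not scikit-learn's actual helpers:

```python
import inspect


def check_init_attrs(estimator, tags):
    # Names accepted by __init__ (excluding self).
    init_params = set(
        inspect.signature(type(estimator).__init__).parameters
    ) - {"self"}
    allowed = init_params | set(tags.get("non_init_params", []))
    invalid = set(vars(estimator)) - allowed
    assert not invalid, (
        f"attributes set in __init__ besides parameters: {invalid}"
    )


class Conformant:
    def __init__(self):
        self.param1 = 42  # not an __init__ parameter, but declared in the tag


class NonConformant:
    def __init__(self):
        self.sneaky_ = 0  # neither a parameter nor declared in any tag


check_init_attrs(Conformant(), {"non_init_params": ["param1"]})  # passes
try:
    check_init_attrs(NonConformant(), {})
except AssertionError as exc:
    print(exc)  # reports 'sneaky_'
```

This mirrors the change above: the tag's contents are simply subtracted from the set of attributes that would otherwise be flagged.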
12 changes: 12 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
@@ -628,6 +628,13 @@ class NonConformantEstimatorNoParamSet(BaseEstimator):
def __init__(self, you_should_set_this_=None):
pass

class ConformantEstimatorWithExtraParams(BaseEstimator):
def _more_tags(self):
return {"non_init_params": ["param1"]}

def __init__(self):
self.param1 = 42

msg = (
"Estimator estimator_name should not set any"
" attribute apart from parameters during init."
@@ -647,6 +654,11 @@ def __init__(self, you_should_set_this_=None):
"estimator_name", NonConformantEstimatorNoParamSet()
)

# this shouldn't raise any errors
check_no_attributes_set_in_init(
"conformant_estimator", ConformantEstimatorWithExtraParams()
)


def test_check_estimator_pairwise():
# check that check_estimator() works on estimator with _pairwise