diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 3c9f6371e0a1e..bbfd12ad39b9c 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1,4 +1,3 @@ -import types import warnings import pickle import re @@ -3297,18 +3296,25 @@ def param_filter(p): tuple, type(None), type, - types.FunctionType, - joblib.Memory, } # Any numpy numeric such as np.int32. allowed_types.update(np.core.numerictypes.allTypes.values()) - assert type(init_param.default) in allowed_types, ( + + allowed_value = ( + type(init_param.default) in allowed_types + or + # Although callables are mutable, we accept them as argument + # default value and trust that neither the implementation of + # the callable nor of the estimator changes the state of the + # callable. + callable(init_param.default) + ) + + assert allowed_value, ( f"Parameter '{init_param.name}' of estimator " f"'{Estimator.__name__}' is of type " - f"{type(init_param.default).__name__} which is not " - "allowed. All init parameters have to be immutable to " - "make cloning possible. Therefore we restrict the set of " - "legal types to " + f"{type(init_param.default).__name__} which is not allowed. " + f"'{init_param.name}' must be a callable or must be of type " f"{set(type.__name__ for type in allowed_types)}." ) if init_param.name not in params.keys():