diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5d72f81182424..95eb096a0f324 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1,5 +1,6 @@ from __future__ import print_function +import types import warnings import sys import traceback @@ -12,12 +13,13 @@ import struct from sklearn.externals.six.moves import zip -from sklearn.externals.joblib import hash +from sklearn.externals.joblib import hash, Memory from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_raises_regex from sklearn.utils.testing import assert_raise_message from sklearn.utils.testing import assert_equal from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_in from sklearn.utils.testing import assert_array_equal from sklearn.utils.testing import assert_array_almost_equal from sklearn.utils.testing import assert_warns_message @@ -1241,6 +1243,8 @@ def check_parameters_default_constructible(name, Estimator): else: return for arg, default in zip(args, defaults): + assert_in(type(default), [str, int, float, bool, tuple, type(None), + np.float64, types.FunctionType, Memory]) if arg not in params.keys(): # deprecated parameter, not in get_params assert_true(default is None)