From 4acf6f41729dd9f5abfae21fe8a1faf6f3990e49 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 17 Oct 2021 23:34:22 -0400 Subject: [PATCH 1/6] ENH Smoke test for invalid parameters in __init__ and get_params --- sklearn/tests/test_common.py | 45 +++++++++++++++++++++++++++++++++ sklearn/utils/metaestimators.py | 6 ++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 4f6818081c67d..15239a939dd8f 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -16,6 +16,7 @@ from functools import partial import pytest +import numpy as np from sklearn.utils import all_estimators from sklearn.utils._testing import ignore_warnings @@ -403,3 +404,47 @@ def test_transformers_get_feature_names_out(transformer): check_transformer_get_feature_names_out_pandas( transformer.__class__.__name__, transformer ) + + +VALIDATE_ESTIMATOR_INIT = [ + "ColumnTransformer", + "FactorAnalysis", + "FastICA", + "FeatureHasher", + "FeatureUnion", + "GridSearchCV", + "HalvingGridSearchCV", + "KernelDensity", + "KernelPCA", + "LabelBinarizer", + "NuSVC", + "NuSVR", + "OneClassSVM", + "Pipeline", + "RadiusNeighborsClassifier", + "SGDOneClassSVM", + "SVC", + "SVR", + "TheilSenRegressor", + "TweedieRegressor", +] +VALIDATE_ESTIMATOR_INIT = set(VALIDATE_ESTIMATOR_INIT) + + +@pytest.mark.parametrize( + "Estimator", + [est for name, est in all_estimators() if name not in VALIDATE_ESTIMATOR_INIT], +) +def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): + """Check that init or set_param does not raise errors.""" + params = signature(Estimator).parameters + + smoke_test_values = [-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []] + for value in smoke_test_values: + new_params = {key: value for key in params} + + # Does not raise + est = Estimator(**new_params) + + # Also do does not raise + est.set_params(**new_params) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 5d71d28c5ffab..50307dc9d0094 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -29,7 +29,11 @@ def _get_params(self, attr, deep=True): out = super().get_params(deep=deep) if not deep: return out + estimators = getattr(self, attr) + if not isinstance(out, dict) or not isinstance(estimators, dict): + return out + out.update(estimators) for name, estimator in estimators: if hasattr(estimator, "get_params"): @@ -45,7 +49,7 @@ def _set_params(self, attr, **params): # 2. Step replacement items = getattr(self, attr) names = [] - if items: + if isinstance(items, list) and items: names, _ = zip(*items) for name in list(params.keys()): if "__" not in name and name in names: From 471532e9c1c790dd00364999492539613aa13182 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 19 Oct 2021 12:22:43 -0400 Subject: [PATCH 2/6] DOC Improve comments --- sklearn/utils/metaestimators.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 50307dc9d0094..8cb607fa682bd 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -46,14 +46,15 @@ def _set_params(self, attr, **params): # 1. All steps if attr in params: setattr(self, attr, params.pop(attr)) - # 2. Step replacement + # 2. Replace items with estimators in params items = getattr(self, attr) - names = [] if isinstance(items, list) and items: - names, _ = zip(*items) - for name in list(params.keys()): - if "__" not in name and name in names: - self._replace_estimator(attr, name, params.pop(name)) + # Get item names used to identify valid names in params + item_names, _ = zip(*items) + for name in list(params.keys()): + if "__" not in name and name in item_names: + self._replace_estimator(attr, name, params.pop(name)) + # 3. Step parameters and other initialisation arguments super().set_params(**params) return self From a7bfccfaf91528ecf1128b36b85184eee80ddd5e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 19 Oct 2021 16:13:07 -0400 Subject: [PATCH 3/6] FIX Fixes bug --- sklearn/utils/metaestimators.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 8cb607fa682bd..b73e7f1e90381 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -31,10 +31,14 @@ def _get_params(self, attr, deep=True): return out estimators = getattr(self, attr) - if not isinstance(out, dict) or not isinstance(estimators, dict): + if not isinstance(out, dict): + return out + + try: + out.update(estimators) + except (TypeError, ValueError): return out - out.update(estimators) for name, estimator in estimators: if hasattr(estimator, "get_params"): for key, value in estimator.get_params(deep=True).items(): From 2dea41f14986f1e99e830db91492bd5c665fbe5f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 20 Oct 2021 13:27:23 -0400 Subject: [PATCH 4/6] CLN Remove unneeded code --- sklearn/utils/metaestimators.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index b73e7f1e90381..1d56395758564 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -31,9 +31,6 @@ def _get_params(self, attr, deep=True): return out estimators = getattr(self, attr) - if not isinstance(out, dict): - return out - try: out.update(estimators) except (TypeError, ValueError): From cc58c4b26f3a60449facef4643fbc686aae17437 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 21 Oct 2021 09:12:31 -0400 Subject: [PATCH 5/6] DOC Adds docstring --- sklearn/utils/metaestimators.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 1d56395758564..142464fac2422 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -33,7 +33,11 @@ def _get_params(self, attr, deep=True): estimators = getattr(self, attr) try: out.update(estimators) - except (TypeError, ValueError): + except TypeError: + # Here we ignore TypeError for cases where estimators is not a list of + # (name, estimator). This is to prevent errors when calling `set_params`. + # `BaseEstimator.set_params` calls `get_params` which can error + # for invalid values for `estimators`. return out for name, estimator in estimators: From cd4f1958559d529d306aa0378c21e4b92c1d617c Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 21 Oct 2021 09:21:59 -0400 Subject: [PATCH 6/6] DOC Adds ValueError in comment --- sklearn/utils/metaestimators.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/metaestimators.py b/sklearn/utils/metaestimators.py index 142464fac2422..eb4d15a29e18d 100644 --- a/sklearn/utils/metaestimators.py +++ b/sklearn/utils/metaestimators.py @@ -33,11 +33,12 @@ def _get_params(self, attr, deep=True): estimators = getattr(self, attr) try: out.update(estimators) - except TypeError: - # Here we ignore TypeError for cases where estimators is not a list of - # (name, estimator). This is to prevent errors when calling `set_params`. - # `BaseEstimator.set_params` calls `get_params` which can error - # for invalid values for `estimators`. + except (TypeError, ValueError): + # Ignore TypeError for cases where estimators is not a list of + # (name, estimator) and ignore ValueError when the list is not + # formated correctly. This is to prevent errors when calling + # `set_params`. `BaseEstimator.set_params` calls `get_params` which + # can error for invalid values for `estimators`. return out for name, estimator in estimators: