Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from functools import partial

import pytest
import numpy as np

from sklearn.utils import all_estimators
from sklearn.utils._testing import ignore_warnings
Expand Down Expand Up @@ -403,3 +404,47 @@ def test_transformers_get_feature_names_out(transformer):
check_transformer_get_feature_names_out_pandas(
transformer.__class__.__name__, transformer
)


# Names of estimators that are left out of the parametrization of
# `test_estimators_do_not_raise_errors_in_init_or_set_params` below.
# Stored as a set so the `name not in ...` filter is a constant-time lookup.
VALIDATE_ESTIMATOR_INIT = {
    "ColumnTransformer",
    "FactorAnalysis",
    "FastICA",
    "FeatureHasher",
    "FeatureUnion",
    "GridSearchCV",
    "HalvingGridSearchCV",
    "KernelDensity",
    "KernelPCA",
    "LabelBinarizer",
    "NuSVC",
    "NuSVR",
    "OneClassSVM",
    "Pipeline",
    "RadiusNeighborsClassifier",
    "SGDOneClassSVM",
    "SVC",
    "SVR",
    "TheilSenRegressor",
    "TweedieRegressor",
}


@pytest.mark.parametrize(
    "Estimator",
    [
        est_cls
        for est_name, est_cls in all_estimators()
        if est_name not in VALIDATE_ESTIMATOR_INIT
    ],
)
def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
    """Smoke-test that ``__init__`` and ``set_params`` never raise.

    Every parameter in the estimator's constructor signature is assigned the
    same arbitrary value (an int, a float, a string, an array, an empty dict
    and an empty list in turn). Building the estimator and then calling
    ``set_params`` with those values must both succeed.
    """
    param_names = list(signature(Estimator).parameters)

    for candidate in (-1, 3.0, "helloworld", np.array([1.0, 4.0]), {}, []):
        # Same value object for every constructor parameter.
        kwargs = dict.fromkeys(param_names, candidate)

        # Construction does not raise
        estimator = Estimator(**kwargs)

        # Neither does set_params
        estimator.set_params(**kwargs)
27 changes: 19 additions & 8 deletions sklearn/utils/metaestimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,18 @@ def _get_params(self, attr, deep=True):
out = super().get_params(deep=deep)
if not deep:
return out

estimators = getattr(self, attr)
out.update(estimators)
try:
out.update(estimators)
except (TypeError, ValueError):
# Ignore TypeError for cases where estimators is not a list of
# (name, estimator) and ignore ValueError when the list is not
# formatted correctly. This is to prevent errors when calling
# `set_params`. `BaseEstimator.set_params` calls `get_params` which
# can error for invalid values for `estimators`.
return out

for name, estimator in estimators:
if hasattr(estimator, "get_params"):
for key, value in estimator.get_params(deep=True).items():
Expand All @@ -42,14 +52,15 @@ def _set_params(self, attr, **params):
# 1. All steps
if attr in params:
setattr(self, attr, params.pop(attr))
# 2. Step replacement
# 2. Replace items with estimators in params
items = getattr(self, attr)
names = []
if items:
names, _ = zip(*items)
for name in list(params.keys()):
if "__" not in name and name in names:
self._replace_estimator(attr, name, params.pop(name))
if isinstance(items, list) and items:
# Get item names used to identify valid names in params
item_names, _ = zip(*items)
for name in list(params.keys()):
if "__" not in name and name in item_names:
self._replace_estimator(attr, name, params.pop(name))

# 3. Step parameters and other initialisation arguments
super().set_params(**params)
return self
Expand Down