diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst
index cfc30ed3b374b..92f35b1346926 100644
--- a/doc/developers/develop.rst
+++ b/doc/developers/develop.rst
@@ -562,15 +562,6 @@ for your estimator's tags. For example::
 You can create a new subclass of :class:`~sklearn.utils.Tags` if you wish
 to add new tags to the existing set.
 
-In addition to the tags, estimators also need to declare any non-optional
-parameters to ``__init__`` in the ``_required_parameters`` class attribute,
-which is a list or tuple. If ``_required_parameters`` is only
-``["estimator"]`` or ``["base_estimator"]``, then the estimator will be
-instantiated with an instance of ``LogisticRegression`` (or
-``RidgeRegression`` if the estimator is a regressor) in the tests. The choice
-of these two models is somewhat idiosyncratic but both should provide robust
-closed-form solutions.
-
 .. _developer_api_set_output:
 
 Developer API for `set_output`
diff --git a/sklearn/base.py b/sklearn/base.py
index 926e61cf4147f..5d770c8a0f1a4 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -1037,9 +1037,12 @@ def fit_predict(self, X, y=None, **kwargs):
 class MetaEstimatorMixin:
     """Mixin class for all meta estimators in scikit-learn.
 
-    This mixin defines the following functionality:
+    This mixin is empty and only exists to indicate that the estimator is a
+    meta-estimator.
 
-    - define `_required_parameters` that specify the mandatory `estimator` parameter.
+    .. versionchanged:: 1.6
+        `_required_parameters` has been removed; it is no longer needed now that
+        the common tests have been refactored and do not use it.
 
     Examples
     --------
@@ -1061,8 +1064,6 @@ class MetaEstimatorMixin:
     LogisticRegression()
     """
 
-    _required_parameters = ["estimator"]
-
 
 class MultiOutputMixin:
     """Mixin to mark estimators that support multioutput."""
diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py
index fea683bb54806..73a0f5e2bd8d1 100644
--- a/sklearn/compose/_column_transformer.py
+++ b/sklearn/compose/_column_transformer.py
@@ -287,8 +287,6 @@ class ColumnTransformer(TransformerMixin, _BaseComposition):
     :ref:`sphx_glr_auto_examples_compose_plot_column_transformer_mixed_types.py`.
     """
 
-    _required_parameters = ["transformers"]
-
     _parameter_constraints: dict = {
         "transformers": [list, Hidden(tuple)],
         "remainder": [
@@ -1322,6 +1320,21 @@ def get_metadata_routing(self):
 
         return router
 
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags._xfail_checks = {
+            "check_estimators_empty_data_messages": "FIXME",
+            "check_estimators_nan_inf": "FIXME",
+            "check_estimator_sparse_array": "FIXME",
+            "check_estimator_sparse_matrix": "FIXME",
+            "check_transformer_data_not_an_array": "FIXME",
+            "check_fit1d": "FIXME",
+            "check_fit2d_predict1d": "FIXME",
+            "check_complex_data": "FIXME",
+            "check_fit2d_1feature": "FIXME",
+        }
+        return tags
+
 
 def _check_X(X):
     """Use check_array only when necessary, e.g. on lists and other non-array-likes."""
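The `__sklearn_tags__` override above replaces per-check special-casing with
declaratively expected failures. For illustration, a minimal sketch of the same
pattern on a custom transformer follows; it assumes the private
`tags._xfail_checks` field as used on this branch (not a stable public API),
and `PassthroughTransformer` is a hypothetical name:

    from sklearn.base import BaseEstimator, TransformerMixin


    class PassthroughTransformer(TransformerMixin, BaseEstimator):
        """Hypothetical transformer, used only to illustrate the tags pattern."""

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            return X

        def __sklearn_tags__(self):
            # Start from the tags assembled by the parent classes, then mark
            # individual common checks as expected failures with a reason.
            tags = super().__sklearn_tags__()
            tags._xfail_checks = {
                "check_estimators_nan_inf": "input is passed through untouched",
            }
            return tags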
diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py
index b34507f724af1..2d916dac1bc83 100644
--- a/sklearn/decomposition/_dict_learning.py
+++ b/sklearn/decomposition/_dict_learning.py
@@ -1279,8 +1279,6 @@ class SparseCoder(_BaseSparseCoding, BaseEstimator):
            [ 0.,  1.,  1.,  0.,  0.]])
     """
 
-    _required_parameters = ["dictionary"]
-
     def __init__(
         self,
         dictionary,
diff --git a/sklearn/ensemble/_base.py b/sklearn/ensemble/_base.py
index 1ba0b3bd005f1..2789dd234294e 100644
--- a/sklearn/ensemble/_base.py
+++ b/sklearn/ensemble/_base.py
@@ -4,7 +4,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 from abc import ABCMeta, abstractmethod
-from typing import List
 
 import numpy as np
 from joblib import effective_n_jobs
@@ -106,9 +105,6 @@ class BaseEnsemble(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
         The collection of fitted base estimators.
     """
 
-    # overwrite _required_parameters from MetaEstimatorMixin
-    _required_parameters: List[str] = []
-
     @abstractmethod
     def __init__(
         self,
@@ -200,8 +196,6 @@ class _BaseHeterogeneousEnsemble(
         appear in `estimators_`.
     """
 
-    _required_parameters = ["estimators"]
-
     @property
     def named_estimators(self):
         """Dictionary to access any fitted sub-estimators by name.
diff --git a/sklearn/model_selection/_classification_threshold.py b/sklearn/model_selection/_classification_threshold.py
index 26a94baa33f15..3505d89e1a31a 100644
--- a/sklearn/model_selection/_classification_threshold.py
+++ b/sklearn/model_selection/_classification_threshold.py
@@ -87,7 +87,6 @@ class BaseThresholdClassifier(ClassifierMixin, MetaEstimatorMixin, BaseEstimator
         error.
     """
 
-    _required_parameters = ["estimator"]
     _parameter_constraints: dict = {
         "estimator": [
             HasMethods(["fit", "predict_proba"]),
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index d0f3a3ba42ef9..a4f2e603e73cc 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -1532,8 +1532,6 @@ class GridSearchCV(BaseSearchCV):
     'std_fit_time', 'std_score_time', 'std_test_score']
     """
 
-    _required_parameters = ["estimator", "param_grid"]
-
     _parameter_constraints: dict = {
         **BaseSearchCV._parameter_constraints,
         "param_grid": [dict, list],
@@ -1913,8 +1911,6 @@ class RandomizedSearchCV(BaseSearchCV):
     {'C': np.float64(2...), 'penalty': 'l1'}
     """
 
-    _required_parameters = ["estimator", "param_distributions"]
-
     _parameter_constraints: dict = {
         **BaseSearchCV._parameter_constraints,
         "param_distributions": [dict, list],
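With `_required_parameters` gone, nothing tells the common tests how to fill
mandatory constructor arguments, so test instances are now spelled out
explicitly. As a quick illustration, mirroring the `INIT_PARAMS` entry added
later in this patch, a test-ready `GridSearchCV` is built as:

    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV

    # Explicit construction replaces the old _required_parameters-driven
    # instantiation with a LogisticRegression/Ridge default.
    search = GridSearchCV(
        estimator=LogisticRegression(C=1), param_grid={"C": [1.0]}, cv=3
    )
    print(search)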
diff --git a/sklearn/model_selection/_search_successive_halving.py b/sklearn/model_selection/_search_successive_halving.py
index b8d0b3068a81f..d0b0096d1d4eb 100644
--- a/sklearn/model_selection/_search_successive_halving.py
+++ b/sklearn/model_selection/_search_successive_halving.py
@@ -378,6 +378,9 @@ def __sklearn_tags__(self):
                     "Fail during parameter check since min/max resources requires"
                     " more samples"
                 ),
+                "check_estimators_nan_inf": "FIXME",
+                "check_classifiers_one_label_sample_weights": "FIXME",
+                "check_fit2d_1feature": "FIXME",
             }
         )
         return tags
@@ -668,8 +671,6 @@ class HalvingGridSearchCV(BaseSuccessiveHalving):
     {'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9}
     """
 
-    _required_parameters = ["estimator", "param_grid"]
-
     _parameter_constraints: dict = {
         **BaseSuccessiveHalving._parameter_constraints,
         "param_grid": [dict, list],
@@ -1018,8 +1019,6 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving):
     {'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9}
     """
 
-    _required_parameters = ["estimator", "param_distributions"]
-
     _parameter_constraints: dict = {
         **BaseSuccessiveHalving._parameter_constraints,
         "param_distributions": [dict, list],
diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index a7ef4f2a7b6c0..41daced76c1a9 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -152,8 +152,6 @@ class Pipeline(_BaseComposition):
     """
 
     # BaseEstimator interface
-    _required_parameters = ["steps"]
-
     _parameter_constraints: dict = {
         "steps": [list, Hidden(tuple)],
         "memory": [None, str, HasMethods(["cache"])],
@@ -1427,8 +1425,6 @@ class FeatureUnion(TransformerMixin, _BaseComposition):
     :ref:`sphx_glr_auto_examples_compose_plot_feature_union.py`.
     """
 
-    _required_parameters = ["transformer_list"]
-
     def __init__(
         self,
         transformer_list,
@@ -1882,6 +1878,15 @@ def get_metadata_routing(self):
 
         return router
 
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags._xfail_checks = {
+            "check_estimators_overwrite_params": "FIXME",
+            "check_estimators_nan_inf": "FIXME",
+            "check_dont_overwrite_parameters": "FIXME",
+        }
+        return tags
+
 
 def make_union(*transformers, n_jobs=None, verbose=False):
     """Construct a :class:`FeatureUnion` from the given transformers.
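`Pipeline` and `FeatureUnion` likewise lose their `_required_parameters` and
instead get fully specified test instances. The constructions below mirror the
`INIT_PARAMS` entries introduced further down in this patch:

    from sklearn.linear_model import Ridge
    from sklearn.pipeline import FeatureUnion, Pipeline
    from sklearn.preprocessing import StandardScaler

    # These match the INIT_PARAMS entries for the two composite estimators.
    pipe = Pipeline(steps=[("scaler", StandardScaler()), ("est", Ridge())])
    union = FeatureUnion(transformer_list=[("trans1", StandardScaler())])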
diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index 9547b565ef54d..aefc7b03e1615 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -26,14 +26,15 @@
     MeanShift,
     SpectralClustering,
 )
+from sklearn.compose import ColumnTransformer
 from sklearn.datasets import make_blobs
 from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
+
+# make it possible to discover experimental estimators when calling `all_estimators`
 from sklearn.experimental import (
     enable_halving_search_cv,  # noqa
     enable_iterative_imputer,  # noqa
 )
-
-# make it possible to discover experimental estimators when calling `all_estimators`
 from sklearn.linear_model import LogisticRegression
 from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
 from sklearn.neighbors import (
@@ -43,7 +44,7 @@
     RadiusNeighborsClassifier,
     RadiusNeighborsRegressor,
 )
-from sklearn.pipeline import make_pipeline
+from sklearn.pipeline import FeatureUnion, make_pipeline
 from sklearn.preprocessing import (
     FunctionTransformer,
     MinMaxScaler,
@@ -54,11 +55,9 @@
 from sklearn.utils import all_estimators
 from sklearn.utils._tags import get_tags
 from sklearn.utils._test_common.instance_generator import (
-    _generate_column_transformer_instances,
     _generate_pipeline,
     _generate_search_cv_instances,
     _get_check_estimator_ids,
-    _set_checking_parameters,
     _tested_estimators,
 )
 from sklearn.utils._testing import (
@@ -139,7 +138,6 @@ def test_estimators(estimator, check, request):
     with ignore_warnings(
         category=(FutureWarning, ConvergenceWarning, UserWarning, LinAlgWarning)
     ):
-        _set_checking_parameters(estimator)
         check(estimator)
 
 
@@ -285,7 +283,6 @@ def check_field_types(tags, defaults):
     "estimator", _tested_estimators(), ids=_get_check_estimator_ids
 )
 def test_check_n_features_in_after_fitting(estimator):
-    _set_checking_parameters(estimator)
     check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)
 
 
@@ -324,7 +321,8 @@ def _estimators_that_predict_in_fit():
     "estimator", column_name_estimators, ids=_get_check_estimator_ids
 )
 def test_pandas_column_name_consistency(estimator):
-    _set_checking_parameters(estimator)
+    if isinstance(estimator, ColumnTransformer):
+        pytest.skip("ColumnTransformer is not tested here")
     with ignore_warnings(category=(FutureWarning)):
         with warnings.catch_warnings(record=True) as record:
             check_dataframe_column_names_consistency(
@@ -360,7 +358,6 @@ def _include_in_get_feature_names_out_check(transformer):
     "transformer", GET_FEATURES_OUT_ESTIMATORS, ids=_get_check_estimator_ids
 )
 def test_transformers_get_feature_names_out(transformer):
-    _set_checking_parameters(transformer)
 
     with ignore_warnings(category=(FutureWarning)):
         check_transformer_get_feature_names_out(
@@ -381,7 +378,6 @@ def test_transformers_get_feature_names_out(transformer):
 )
 def test_estimators_get_feature_names_out_error(estimator):
     estimator_name = estimator.__class__.__name__
-    _set_checking_parameters(estimator)
     check_get_feature_names_out_error(estimator_name, estimator)
 
 
@@ -409,14 +405,14 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
         chain(
             _tested_estimators(),
             _generate_pipeline(),
-            _generate_column_transformer_instances(),
             _generate_search_cv_instances(),
         ),
     ids=_get_check_estimator_ids,
 )
 def test_check_param_validation(estimator):
+    if isinstance(estimator, FeatureUnion):
+        pytest.skip("FeatureUnion is not tested here")
     name = estimator.__class__.__name__
-    _set_checking_parameters(estimator)
     check_param_validation(name, estimator)
 
 
@@ -481,7 +477,6 @@ def test_set_output_transform(estimator):
             f"Skipping check_set_output_transform for {name}: Does not support"
             " set_output API"
         )
-    _set_checking_parameters(estimator)
     with ignore_warnings(category=(FutureWarning)):
         check_set_output_transform(estimator.__class__.__name__, estimator)
 
@@ -505,7 +500,6 @@ def test_set_output_transform_configured(estimator, check_func):
             f"Skipping {check_func.__name__} for {name}: Does not support"
             " set_output API yet"
         )
-    _set_checking_parameters(estimator)
     with ignore_warnings(category=(FutureWarning)):
         check_func(estimator.__class__.__name__, estimator)
 
@@ -523,8 +517,6 @@ def test_check_inplace_ensure_writeable(estimator):
     else:
         raise SkipTest(f"{name} doesn't require writeable input.")
 
-    _set_checking_parameters(estimator)
-
     # The following estimators can work inplace only with certain settings
     if name == "HDBSCAN":
         estimator.set_params(metric="precomputed", algorithm="brute")
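Every `_set_checking_parameters(estimator)` call disappears from the test
bodies because instances now arrive pre-configured from the instance
generator. A hypothetical one-liner showing what the removed helper reduces to
under the new scheme (the `fast_instance` name is illustrative, not part of
the codebase):

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils._test_common.instance_generator import INIT_PARAMS

    def fast_instance(Estimator):
        # Build the estimator directly with the fast test settings instead
        # of mutating a default instance after the fact.
        return Estimator(**INIT_PARAMS.get(Estimator, {}))

    print(fast_instance(LogisticRegression))  # LogisticRegression(max_iter=5)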
diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py
index 519fdca2a865b..aff5d58a8f3a7 100644
--- a/sklearn/utils/_test_common/instance_generator.py
+++ b/sklearn/utils/_test_common/instance_generator.py
@@ -9,7 +9,6 @@
 from itertools import product
 
 from sklearn import config_context
-from sklearn.base import RegressorMixin
 from sklearn.calibration import CalibratedClassifierCV
 from sklearn.cluster import (
     HDBSCAN,
@@ -39,6 +38,7 @@
     MiniBatchDictionaryLearning,
     MiniBatchNMF,
     MiniBatchSparsePCA,
+    SparseCoder,
     SparsePCA,
     TruncatedSVD,
 )
@@ -60,10 +60,13 @@
     RandomTreesEmbedding,
     StackingClassifier,
     StackingRegressor,
+    VotingClassifier,
+    VotingRegressor,
 )
 from sklearn.exceptions import SkipTestWarning
 from sklearn.experimental import enable_halving_search_cv  # noqa
 from sklearn.feature_selection import (
+    RFE,
     RFECV,
     SelectFdr,
     SelectFromModel,
@@ -106,16 +109,27 @@
 from sklearn.manifold import MDS, TSNE, LocallyLinearEmbedding, SpectralEmbedding
 from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
 from sklearn.model_selection import (
+    FixedThresholdClassifier,
     GridSearchCV,
     HalvingGridSearchCV,
     HalvingRandomSearchCV,
     RandomizedSearchCV,
     TunedThresholdClassifierCV,
 )
-from sklearn.multioutput import ClassifierChain, RegressorChain
+from sklearn.multiclass import (
+    OneVsOneClassifier,
+    OneVsRestClassifier,
+    OutputCodeClassifier,
+)
+from sklearn.multioutput import (
+    ClassifierChain,
+    MultiOutputClassifier,
+    MultiOutputRegressor,
+    RegressorChain,
+)
 from sklearn.neighbors import NeighborhoodComponentsAnalysis
 from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor
-from sklearn.pipeline import Pipeline, make_pipeline
+from sklearn.pipeline import FeatureUnion, Pipeline, make_pipeline
 from sklearn.preprocessing import OneHotEncoder, StandardScaler, TargetEncoder
 from sklearn.random_projection import (
     GaussianRandomProjection,
@@ -135,7 +149,7 @@
 
 # The following dictionary is to indicate constructor arguments suitable for the test
 # suite, which uses very small datasets, and is intended to run rather quickly.
-TEST_PARAMS = {
+INIT_PARAMS = {
     AdaBoostClassifier: dict(n_estimators=5),
     AdaBoostRegressor: dict(n_estimators=5),
     AffinityPropagation: dict(max_iter=5),
@@ -148,9 +162,10 @@
     BernoulliRBM: dict(n_iter=5, batch_size=10),
     Birch: dict(n_clusters=2),
     BisectingKMeans: dict(n_init=2, n_clusters=2, max_iter=5),
-    CalibratedClassifierCV: dict(cv=3),
+    CalibratedClassifierCV: dict(estimator=LogisticRegression(C=1), cv=3),
     CCA: dict(n_components=1, max_iter=5),
-    ClassifierChain: dict(cv=3),
+    ClassifierChain: dict(base_estimator=LogisticRegression(C=1), cv=3),
+    ColumnTransformer: dict(transformers=[("trans1", StandardScaler(), [0, 1])]),
     DictionaryLearning: dict(max_iter=20, transform_algorithm="lasso_lars"),
     # the default strategy prior would output constant predictions and fail
     # for check_classifiers_predictions
@@ -162,6 +177,8 @@
     FactorAnalysis: dict(max_iter=5),
     FastICA: dict(max_iter=5),
     FeatureAgglomeration: dict(n_clusters=2),
+    FeatureUnion: dict(transformer_list=[("trans1", StandardScaler())]),
+    FixedThresholdClassifier: dict(estimator=LogisticRegression(C=1)),
     GammaRegressor: dict(max_iter=5),
     GaussianMixture: dict(n_init=2, max_iter=5),
     # Due to the jl lemma and often very few samples, the number
@@ -173,9 +190,25 @@
     GradientBoostingRegressor: dict(n_estimators=5),
     GraphicalLassoCV: dict(max_iter=5, cv=3),
     GraphicalLasso: dict(max_iter=5),
-    GridSearchCV: dict(cv=3),
-    HalvingGridSearchCV: dict(cv=3),
-    HalvingRandomSearchCV: dict(cv=3),
+    GridSearchCV: dict(
+        estimator=LogisticRegression(C=1), param_grid={"C": [1.0]}, cv=3
+    ),
+    HalvingGridSearchCV: dict(
+        estimator=Ridge(),
+        min_resources="smallest",
+        param_grid={"alpha": [0.1, 1.0]},
+        random_state=0,
+        cv=2,
+        error_score="raise",
+    ),
+    HalvingRandomSearchCV: dict(
+        estimator=Ridge(),
+        param_distributions={"alpha": [0.1, 1.0]},
+        min_resources="smallest",
+        cv=2,
+        error_score="raise",
+        random_state=0,
+    ),
     HDBSCAN: dict(min_samples=1),
     # The default min_samples_leaf (20) isn't appropriate for small
     # datasets (only very shallow trees are built) that the checks use.
@@ -196,8 +229,8 @@
     # We need to provide the noise variance explicitly.
     LassoLarsIC: dict(max_iter=5, noise_variance=1.0),
     LatentDirichletAllocation: dict(max_iter=5, batch_size=10),
-    LinearSVR: dict(max_iter=20),
     LinearSVC: dict(max_iter=20),
+    LinearSVR: dict(max_iter=20),
     LocallyLinearEmbedding: dict(max_iter=5),
     LogisticRegressionCV: dict(max_iter=5, cv=3),
     LogisticRegression: dict(max_iter=5),
@@ -212,6 +245,8 @@
     MiniBatchSparsePCA: dict(max_iter=5, batch_size=10),
     MLPClassifier: dict(max_iter=100),
     MLPRegressor: dict(max_iter=100),
+    MultiOutputClassifier: dict(estimator=LogisticRegression(C=1)),
+    MultiOutputRegressor: dict(estimator=Ridge()),
     MultiTaskElasticNetCV: dict(max_iter=5, cv=3),
     MultiTaskElasticNet: dict(max_iter=5),
     MultiTaskLassoCV: dict(max_iter=5, cv=3),
@@ -222,28 +257,44 @@
     NuSVR: dict(max_iter=-1),
     OneClassSVM: dict(max_iter=-1),
     OneHotEncoder: dict(handle_unknown="ignore"),
+    OneVsOneClassifier: dict(estimator=LogisticRegression(C=1)),
+    OneVsRestClassifier: dict(estimator=LogisticRegression(C=1)),
     OrthogonalMatchingPursuitCV: dict(cv=3),
+    OutputCodeClassifier: dict(estimator=LogisticRegression(C=1)),
     PassiveAggressiveClassifier: dict(max_iter=5),
     PassiveAggressiveRegressor: dict(max_iter=5),
     Perceptron: dict(max_iter=5),
+    Pipeline: dict(steps=[("scaler", StandardScaler()), ("est", Ridge())]),
     PLSCanonical: dict(n_components=1, max_iter=5),
     PLSRegression: dict(n_components=1, max_iter=5),
     PLSSVD: dict(n_components=1),
     PoissonRegressor: dict(max_iter=5),
     RandomForestClassifier: dict(n_estimators=5),
     RandomForestRegressor: dict(n_estimators=5),
-    RandomizedSearchCV: dict(n_iter=5, cv=3),
+    RandomizedSearchCV: dict(
+        estimator=LogisticRegression(C=1),
+        param_distributions={"C": [1.0]},
+        n_iter=5,
+        cv=3,
+    ),
     RandomTreesEmbedding: dict(n_estimators=5),
-    RANSACRegressor: dict(max_trials=10),
-    RegressorChain: dict(cv=3),
-    RFECV: dict(cv=3),
+    # `RANSACRegressor` will raise an error with any model other
+    # than `LinearRegression` if we don't fix the `min_samples` parameter.
+    # For the common tests, we therefore enforce `LinearRegression`, which is
+    # the default estimator of `RANSACRegressor`, instead of `Ridge`.
+    RANSACRegressor: dict(estimator=LinearRegression(), max_trials=10),
+    RegressorChain: dict(base_estimator=Ridge(), cv=3),
+    RFECV: dict(estimator=LogisticRegression(C=1), cv=3),
+    RFE: dict(estimator=LogisticRegression(C=1)),
     # be tolerant of noisy datasets (not actually speed)
     SelectFdr: dict(alpha=0.5),
+    # Increases coverage because SGDRegressor has partial_fit
+    SelectFromModel: dict(estimator=SGDRegressor(random_state=0)),
     # SelectKBest has a default of k=10
     # which is more feature than we have in most case.
     SelectKBest: dict(k=1),
-    SelfTrainingClassifier: dict(max_iter=5),
-    SequentialFeatureSelector: dict(cv=3),
+    SelfTrainingClassifier: dict(estimator=LogisticRegression(C=1), max_iter=5),
+    SequentialFeatureSelector: dict(estimator=LogisticRegression(C=1), cv=3),
     SGDClassifier: dict(max_iter=5),
     SGDOneClassSVM: dict(max_iter=5),
     SGDRegressor: dict(max_iter=5),
@@ -258,9 +309,21 @@
     SpectralCoclustering: dict(n_init=2, n_clusters=2),
     # Default "auto" parameter can lead to different ordering of eigenvalues on
     # windows: #24105
-    SpectralEmbedding: dict(eigen_tol=1e-5),
-    StackingClassifier: dict(cv=3),
-    StackingRegressor: dict(cv=3),
+    SpectralEmbedding: dict(eigen_tol=1e-05),
+    StackingClassifier: dict(
+        estimators=[
+            ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)),
+            ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)),
+        ],
+        cv=3,
+    ),
+    StackingRegressor: dict(
+        estimators=[
+            ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)),
+            ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)),
+        ],
+        cv=3,
+    ),
     SVC: dict(max_iter=-1),
     SVR: dict(max_iter=-1),
     TargetEncoder: dict(cv=3),
@@ -268,19 +331,23 @@
     # TruncatedSVD doesn't run with n_components = n_features
     TruncatedSVD: dict(n_iter=5, n_components=1),
     TSNE: dict(perplexity=2),
-    TunedThresholdClassifierCV: dict(cv=3),
+    TunedThresholdClassifierCV: dict(estimator=LogisticRegression(C=1), cv=3),
     TweedieRegressor: dict(max_iter=5),
+    VotingClassifier: dict(
+        estimators=[
+            ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)),
+            ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)),
+        ]
+    ),
+    VotingRegressor: dict(
+        estimators=[
+            ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)),
+            ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)),
+        ]
+    ),
 }
 
 
-def _set_checking_parameters(estimator):
-    """Set the parameters of an estimator instance to speed-up tests and avoid
-    deprecation warnings in common test."""
-    if type(estimator) in TEST_PARAMS:
-        test_params = TEST_PARAMS[type(estimator)]
-        estimator.set_params(**test_params)
-
-
 def _tested_estimators(type_filter=None):
     for name, Estimator in all_estimators(type_filter=type_filter):
         try:
@@ -304,48 +371,19 @@ def _generate_pipeline():
     )
 
 
+SKIPPED_ESTIMATORS = [SparseCoder]
+
+
 def _construct_instance(Estimator):
     """Construct Estimator instance if possible."""
-    required_parameters = getattr(Estimator, "_required_parameters", [])
-    if len(required_parameters):
-        if required_parameters in (["estimator"], ["base_estimator"]):
-            # `RANSACRegressor` will raise an error with any model other
-            # than `LinearRegression` if we don't fix `min_samples` parameter.
-            # For common test, we can enforce using `LinearRegression` that
-            # is the default estimator in `RANSACRegressor` instead of `Ridge`.
-            if issubclass(Estimator, RANSACRegressor):
-                estimator = Estimator(LinearRegression())
-            elif issubclass(Estimator, RegressorMixin):
-                estimator = Estimator(Ridge())
-            elif issubclass(Estimator, SelectFromModel):
-                # Increases coverage because SGDRegressor has partial_fit
-                estimator = Estimator(SGDRegressor(random_state=0))
-            else:
-                estimator = Estimator(LogisticRegression(C=1))
-        elif required_parameters in (["estimators"],):
-            # Heterogeneous ensemble classes (i.e. stacking, voting)
-            if issubclass(Estimator, RegressorMixin):
-                estimator = Estimator(
-                    estimators=[
-                        ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)),
-                        ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)),
-                    ]
-                )
-            else:
-                estimator = Estimator(
-                    estimators=[
-                        ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)),
-                        ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)),
-                    ]
-                )
-        else:
-            msg = (
-                f"Can't instantiate estimator {Estimator.__name__} "
-                f"parameters {required_parameters}"
-            )
-            # raise additional warning to be shown by pytest
-            warnings.warn(msg, SkipTestWarning)
-            raise SkipTest(msg)
+    if Estimator in SKIPPED_ESTIMATORS:
+        msg = f"Can't instantiate estimator {Estimator.__name__}"
+        # raise additional warning to be shown by pytest
+        warnings.warn(msg, SkipTestWarning)
+        raise SkipTest(msg)
+
+    if Estimator in INIT_PARAMS:
+        estimator = Estimator(**INIT_PARAMS[Estimator])
     else:
         estimator = Estimator()
     return estimator
@@ -387,16 +425,6 @@ def _get_check_estimator_ids(obj):
         return re.sub(r"\s", "", str(obj))
 
 
-def _generate_column_transformer_instances():
-    """Generate a `ColumnTransformer` instance to check its compliance with
-    scikit-learn."""
-    yield ColumnTransformer(
-        transformers=[
-            ("trans1", StandardScaler(), [0, 1]),
-        ]
-    )
-
-
 def _generate_search_cv_instances():
     """Generator of `SearchCV` instances to check their compliance with scikit-learn."""
     for SearchCV, (Estimator, param_grid) in product(
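The rewritten `_construct_instance` is now a plain dictionary lookup. A short
usage sketch against this branch (the printed repr is indicative only):

    from sklearn.feature_selection import RFE
    from sklearn.utils._test_common.instance_generator import _construct_instance

    # RFE has a required `estimator` argument; it is now filled from
    # INIT_PARAMS instead of being inferred from _required_parameters.
    est = _construct_instance(RFE)
    print(est)  # an RFE wrapping LogisticRegression(C=1)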
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 22557c45cdead..4680c987b4b3a 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -58,6 +58,7 @@
 from ._tags import Tags, get_tags
 from ._test_common.instance_generator import (
     CROSS_DECOMPOSITION,
+    INIT_PARAMS,
     _construct_instance,
     _get_check_estimator_ids,
 )
@@ -82,6 +83,8 @@
 
 
 def _yield_api_checks(estimator):
+    yield check_estimator_cloneable
+    yield check_estimator_repr
     yield check_no_attributes_set_in_init
     yield check_fit_score_takes_y
     yield check_estimators_overwrite_params
@@ -3256,18 +3259,31 @@ def check_estimators_data_not_an_array(name, estimator_orig, X, y, obj_type):
     assert_allclose(pred1, pred2, atol=1e-2, err_msg=name)
 
 
-def check_parameters_default_constructible(name, Estimator):
+def check_estimator_cloneable(name, estimator_orig):
+    """Check whether the estimator can be cloned."""
+    try:
+        clone(estimator_orig)
+    except Exception as e:
+        raise AssertionError(f"Cloning of {name} failed with error: {e}.") from e
+
+
+def check_estimator_repr(name, estimator_orig):
+    """Check that the estimator has a functioning repr."""
+    estimator = clone(estimator_orig)
+    try:
+        repr(estimator)
+    except Exception as e:
+        raise AssertionError(f"Repr of {name} failed with error: {e}.") from e
+
+
+def check_parameters_default_constructible(name, estimator_orig):
     # test default-constructibility
     # get rid of deprecation warnings
-    Estimator = Estimator.__class__
+    Estimator = estimator_orig.__class__
 
     with ignore_warnings(category=FutureWarning):
         estimator = _construct_instance(Estimator)
-        # test cloning
-        clone(estimator)
-        # test __repr__
-        repr(estimator)
         # test that set_params returns self
         assert estimator.set_params() is estimator
 
 
@@ -3287,6 +3303,8 @@ def param_filter(p):
                 p.name != "self"
                 and p.kind != p.VAR_KEYWORD
                 and p.kind != p.VAR_POSITIONAL
+                # and it should have a default value for this test
+                and p.default != p.empty
             )
 
         init_params = [
@@ -3298,10 +3316,15 @@
         # true for mixins
         return
     params = estimator.get_params()
-    # they can need a non-default argument
-    init_params = init_params[len(getattr(estimator, "_required_parameters", [])) :]
 
     for init_param in init_params:
+        if (
+            type(estimator) in INIT_PARAMS
+            and init_param.name in INIT_PARAMS[type(estimator)]
+        ):
+            # these parameters come from INIT_PARAMS rather than the class
+            # defaults, so they are ignored here.
+            continue
         assert (
             init_param.default != init_param.empty
         ), "parameter %s for %s has no default value" % (
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index c5533a4e514ca..34e549ba143a9 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -31,7 +31,6 @@
 from sklearn.utils import _array_api, all_estimators, deprecated
 from sklearn.utils._param_validation import Interval, StrOptions
 from sklearn.utils._tags import default_tags
-from sklearn.utils._test_common.instance_generator import _set_checking_parameters
 from sklearn.utils._testing import (
     MinimalClassifier,
     MinimalRegressor,
@@ -52,6 +51,8 @@
     check_dataframe_column_names_consistency,
     check_decision_proba_consistency,
     check_estimator,
+    check_estimator_cloneable,
+    check_estimator_repr,
     check_estimators_unfitted,
     check_fit_check_is_fitted,
     check_fit_score_takes_y,
@@ -752,7 +753,6 @@ def test_check_estimator_clones():
     # without fitting
     with ignore_warnings(category=ConvergenceWarning):
         est = Estimator()
-        _set_checking_parameters(est)
         set_random_state(est)
         old_hash = joblib.hash(est)
         check_estimator(est)
@@ -761,7 +761,6 @@
     # with fitting
     with ignore_warnings(category=ConvergenceWarning):
         est = Estimator()
-        _set_checking_parameters(est)
         set_random_state(est)
         est.fit(iris.data + 10, iris.target)
         old_hash = joblib.hash(est)
@@ -1303,3 +1302,29 @@ def get_check_name(check):
     non_legacy_check_names = {get_check_name(check) for check in non_legacy_checks}
     legacy_check_names = {get_check_name(check) for check in legacy_checks}
     assert non_legacy_check_names.issubset(legacy_check_names)
+
+
+def test_check_estimator_cloneable_error():
+    """Check that the right error is raised when the estimator is not cloneable."""
+
+    class NotCloneable(BaseEstimator):
+        def __sklearn_clone__(self):
+            raise NotImplementedError("This estimator is not cloneable.")
+
+    estimator = NotCloneable()
+    msg = "Cloning of .* failed with error"
+    with raises(AssertionError, match=msg):
+        check_estimator_cloneable("NotCloneable", estimator)
+
+
+def test_estimator_repr_error():
+    """Check that the right error is raised when the estimator does not have a repr."""
+
+    class NotRepr(BaseEstimator):
+        def __repr__(self):
+            raise NotImplementedError("This estimator does not have a repr.")
+
+    estimator = NotRepr()
+    msg = "Repr of .* failed with error"
+    with raises(AssertionError, match=msg):
+        check_estimator_repr("NotRepr", estimator)
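The two new API checks can also be called directly, for instance to probe a
single estimator outside the common-test harness. A minimal sketch; on this
branch both functions pass silently on a well-behaved estimator and raise an
AssertionError otherwise:

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.estimator_checks import (
        check_estimator_cloneable,
        check_estimator_repr,
    )

    est = LogisticRegression()
    check_estimator_cloneable("LogisticRegression", est)  # no output on success
    check_estimator_repr("LogisticRegression", est)  # no output on success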