diff --git a/doc/sphinxext/allow_nan_estimators.py b/doc/sphinxext/allow_nan_estimators.py
index 89d7077bce2b5..00a6ddc0048e9 100755
--- a/doc/sphinxext/allow_nan_estimators.py
+++ b/doc/sphinxext/allow_nan_estimators.py
@@ -4,8 +4,8 @@
 from docutils.parsers.rst import Directive
 
 from sklearn.utils import all_estimators
+from sklearn.utils._test_common.instance_generator import _construct_instance
 from sklearn.utils._testing import SkipTest
-from sklearn.utils.estimator_checks import _construct_instance
 
 
 class AllowNanEstimators(Directive):
diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py
index bd7f60061abdc..52f769bfb9001 100644
--- a/sklearn/decomposition/tests/test_pca.py
+++ b/sklearn/decomposition/tests/test_pca.py
@@ -17,9 +17,9 @@
     yield_namespace_device_dtype_combinations,
 )
 from sklearn.utils._array_api import device as array_device
+from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
 from sklearn.utils._testing import _array_api_for_tests, assert_allclose
 from sklearn.utils.estimator_checks import (
-    _get_check_estimator_ids,
     check_array_api_input_and_values,
 )
 from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py
index 8978ab40b072f..008ccf11d6ac3 100644
--- a/sklearn/linear_model/tests/test_ridge.py
+++ b/sklearn/linear_model/tests/test_ridge.py
@@ -48,6 +48,7 @@
     yield_namespace_device_dtype_combinations,
     yield_namespaces,
 )
+from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
 from sklearn.utils._testing import (
     assert_allclose,
     assert_almost_equal,
@@ -57,7 +58,6 @@
 )
 from sklearn.utils.estimator_checks import (
     _array_api_for_tests,
-    _get_check_estimator_ids,
     check_array_api_input_and_values,
 )
 from sklearn.utils.fixes import (
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index 7273644557fcd..1b3d959351634 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -40,6 +40,7 @@
 from sklearn.utils._array_api import (
     yield_namespace_device_dtype_combinations,
 )
+from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
 from sklearn.utils._testing import (
     _convert_container,
     assert_allclose,
@@ -51,7 +52,6 @@
     skip_if_32bit,
 )
 from sklearn.utils.estimator_checks import (
-    _get_check_estimator_ids,
     check_array_api_input_and_values,
 )
 from sklearn.utils.fixes import (
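The four test modules above only change where `_get_check_estimator_ids` is imported from; how it is consumed is unchanged. As a minimal sketch of that usage pattern (the test body below is illustrative, not part of the diff):

```python
# Hedged sketch: _get_check_estimator_ids as a pytest `ids` callback, the
# pattern used by the array-API tests touched above.
import pytest

from sklearn.decomposition import PCA
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids


@pytest.mark.parametrize(
    "estimator", [PCA(n_components=2)], ids=_get_check_estimator_ids
)
def test_example(estimator):
    # The generated test id reads "PCA(n_components=2)" instead of "estimator0".
    assert hasattr(estimator, "fit")
```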
diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py
index 70976a3f6acb8..f9d0fc8983a4c 100644
--- a/sklearn/tests/test_common.py
+++ b/sklearn/tests/test_common.py
@@ -11,7 +11,7 @@
 import warnings
 from functools import partial
 from inspect import isgenerator, signature
-from itertools import chain, product
+from itertools import chain
 
 import numpy as np
 import pytest
@@ -26,24 +26,16 @@
     MeanShift,
     SpectralClustering,
 )
-from sklearn.compose import ColumnTransformer
 from sklearn.datasets import make_blobs
-from sklearn.decomposition import PCA
 from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
 
 # make it possible to discover experimental estimators when calling `all_estimators`
 from sklearn.experimental import (
     enable_halving_search_cv,  # noqa
     enable_iterative_imputer,  # noqa
 )
-from sklearn.linear_model import LogisticRegression, Ridge
+from sklearn.linear_model import LogisticRegression
 from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
-from sklearn.model_selection import (
-    GridSearchCV,
-    HalvingGridSearchCV,
-    HalvingRandomSearchCV,
-    RandomizedSearchCV,
-)
 from sklearn.neighbors import (
     KNeighborsClassifier,
     KNeighborsRegressor,
@@ -51,7 +43,7 @@
     RadiusNeighborsClassifier,
     RadiusNeighborsRegressor,
 )
-from sklearn.pipeline import Pipeline, make_pipeline
+from sklearn.pipeline import make_pipeline
 from sklearn.preprocessing import (
     FunctionTransformer,
     MinMaxScaler,
@@ -61,15 +53,19 @@
 from sklearn.semi_supervised import LabelPropagation, LabelSpreading
 from sklearn.utils import all_estimators
 from sklearn.utils._tags import _DEFAULT_TAGS, _safe_tags
+from sklearn.utils._test_common.instance_generator import (
+    _generate_column_transformer_instances,
+    _generate_pipeline,
+    _generate_search_cv_instances,
+    _get_check_estimator_ids,
+    _set_checking_parameters,
+    _tested_estimators,
+)
 from sklearn.utils._testing import (
     SkipTest,
     ignore_warnings,
-    set_random_state,
 )
 from sklearn.utils.estimator_checks import (
-    _construct_instance,
-    _get_check_estimator_ids,
-    _set_checking_parameters,
     check_dataframe_column_names_consistency,
     check_estimator,
     check_get_feature_names_out_error,
@@ -137,26 +133,6 @@ def test_get_check_estimator_ids(val, expected):
     assert _get_check_estimator_ids(val) == expected
 
 
-def _tested_estimators(type_filter=None):
-    for name, Estimator in all_estimators(type_filter=type_filter):
-        try:
-            estimator = _construct_instance(Estimator)
-        except SkipTest:
-            continue
-
-        yield estimator
-
-
-def _generate_pipeline():
-    for final_estimator in [Ridge(), LogisticRegression()]:
-        yield Pipeline(
-            steps=[
-                ("scaler", StandardScaler()),
-                ("final_estimator", final_estimator),
-            ]
-        )
-
-
 @parametrize_with_checks(list(chain(_tested_estimators(), _generate_pipeline())))
 def test_estimators(estimator, check, request):
     # Common tests for estimator instances
@@ -259,60 +235,6 @@ def test_class_support_removed():
         parametrize_with_checks([LogisticRegression])
 
 
-def _generate_column_transformer_instances():
-    yield ColumnTransformer(
-        transformers=[
-            ("trans1", StandardScaler(), [0, 1]),
-        ]
-    )
-
-
-def _generate_search_cv_instances():
-    for SearchCV, (Estimator, param_grid) in product(
-        [
-            GridSearchCV,
-            HalvingGridSearchCV,
-            RandomizedSearchCV,
-            HalvingGridSearchCV,
-        ],
-        [
-            (Ridge, {"alpha": [0.1, 1.0]}),
-            (LogisticRegression, {"C": [0.1, 1.0]}),
-        ],
-    ):
-        init_params = signature(SearchCV).parameters
-        extra_params = (
-            {"min_resources": "smallest"} if "min_resources" in init_params else {}
-        )
-        search_cv = SearchCV(
-            Estimator(), param_grid, cv=2, error_score="raise", **extra_params
-        )
-        set_random_state(search_cv)
-        yield search_cv
-
-    for SearchCV, (Estimator, param_grid) in product(
-        [
-            GridSearchCV,
-            HalvingGridSearchCV,
-            RandomizedSearchCV,
-            HalvingRandomSearchCV,
-        ],
-        [
-            (Ridge, {"ridge__alpha": [0.1, 1.0]}),
-            (LogisticRegression, {"logisticregression__C": [0.1, 1.0]}),
-        ],
-    ):
-        init_params = signature(SearchCV).parameters
-        extra_params = (
-            {"min_resources": "smallest"} if "min_resources" in init_params else {}
-        )
-        search_cv = SearchCV(
-            make_pipeline(PCA(), Estimator()), param_grid, cv=2, **extra_params
-        ).set_params(error_score="raise")
-        set_random_state(search_cv)
-        yield search_cv
-
-
 @parametrize_with_checks(list(_generate_search_cv_instances()))
 def test_search_cv(estimator, check, request):
     # Common tests for SearchCV instances
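The generators removed above are consumed unchanged by the `test_estimators` decorator that `test_common.py` keeps. A hedged sketch of that consumption pattern, simplified — the real test also applies `_set_checking_parameters` and warning filters:

```python
# Sketch: the relocated generators chain into parametrize_with_checks
# exactly as before; only their import location changed.
from itertools import chain

from sklearn.utils._test_common.instance_generator import (
    _generate_pipeline,
    _tested_estimators,
)
from sklearn.utils.estimator_checks import parametrize_with_checks


@parametrize_with_checks(list(chain(_tested_estimators(), _generate_pipeline())))
def test_estimators(estimator, check, request):
    # Each (estimator, check) pair runs one common API check.
    check(estimator)
```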
diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py
index 3af463b783bc3..687b85ed00187 100644
--- a/sklearn/tests/test_docstring_parameters.py
+++ b/sklearn/tests/test_docstring_parameters.py
@@ -22,6 +22,7 @@
 from sklearn.linear_model import LogisticRegression
 from sklearn.preprocessing import FunctionTransformer
 from sklearn.utils import all_estimators
+from sklearn.utils._test_common.instance_generator import _construct_instance
 from sklearn.utils._testing import (
     _get_func_name,
     check_docstring_parameters,
@@ -29,7 +30,6 @@
 )
 from sklearn.utils.deprecation import _is_deprecated
 from sklearn.utils.estimator_checks import (
-    _construct_instance,
     _enforce_estimator_tags_X,
     _enforce_estimator_tags_y,
 )
diff --git a/sklearn/utils/_test_common/__init__.py b/sklearn/utils/_test_common/__init__.py
new file mode 100644
index 0000000000000..67dd18fb94b59
--- /dev/null
+++ b/sklearn/utils/_test_common/__init__.py
@@ -0,0 +1,2 @@
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
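`test_docstring_parameters.py` now pulls `_construct_instance` from the new module. A hedged sketch of what the helper returns for a few estimator kinds, following the branches of the function added below (the comments are indicative):

```python
# Sketch: _construct_instance fills in required constructor arguments so
# meta-estimators can be instantiated without user input.
from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain
from sklearn.utils._test_common.instance_generator import _construct_instance

est = _construct_instance(LogisticRegression)  # no required params -> LogisticRegression()
chain = _construct_instance(ClassifierChain)  # classifier branch -> wraps LogisticRegression(C=1)
stack = _construct_instance(StackingClassifier)  # "estimators" branch -> two small DecisionTreeClassifiers
```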
diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py
new file mode 100644
index 0000000000000..519fdca2a865b
--- /dev/null
+++ b/sklearn/utils/_test_common/instance_generator.py
@@ -0,0 +1,444 @@
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+
+import re
+import warnings
+from functools import partial
+from inspect import isfunction, signature
+from itertools import product
+
+from sklearn import config_context
+from sklearn.base import RegressorMixin
+from sklearn.calibration import CalibratedClassifierCV
+from sklearn.cluster import (
+    HDBSCAN,
+    AffinityPropagation,
+    AgglomerativeClustering,
+    Birch,
+    BisectingKMeans,
+    FeatureAgglomeration,
+    KMeans,
+    MeanShift,
+    MiniBatchKMeans,
+    SpectralBiclustering,
+    SpectralClustering,
+    SpectralCoclustering,
+)
+from sklearn.compose import ColumnTransformer
+from sklearn.covariance import GraphicalLasso, GraphicalLassoCV
+from sklearn.cross_decomposition import CCA, PLSSVD, PLSCanonical, PLSRegression
+from sklearn.decomposition import (
+    NMF,
+    PCA,
+    DictionaryLearning,
+    FactorAnalysis,
+    FastICA,
+    IncrementalPCA,
+    LatentDirichletAllocation,
+    MiniBatchDictionaryLearning,
+    MiniBatchNMF,
+    MiniBatchSparsePCA,
+    SparsePCA,
+    TruncatedSVD,
+)
+from sklearn.dummy import DummyClassifier
+from sklearn.ensemble import (
+    AdaBoostClassifier,
+    AdaBoostRegressor,
+    BaggingClassifier,
+    BaggingRegressor,
+    ExtraTreesClassifier,
+    ExtraTreesRegressor,
+    GradientBoostingClassifier,
+    GradientBoostingRegressor,
+    HistGradientBoostingClassifier,
+    HistGradientBoostingRegressor,
+    IsolationForest,
+    RandomForestClassifier,
+    RandomForestRegressor,
+    RandomTreesEmbedding,
+    StackingClassifier,
+    StackingRegressor,
+)
+from sklearn.exceptions import SkipTestWarning
+from sklearn.experimental import enable_halving_search_cv  # noqa
+from sklearn.feature_selection import (
+    RFECV,
+    SelectFdr,
+    SelectFromModel,
+    SelectKBest,
+    SequentialFeatureSelector,
+)
+from sklearn.linear_model import (
+    ARDRegression,
+    BayesianRidge,
+    ElasticNet,
+    ElasticNetCV,
+    GammaRegressor,
+    HuberRegressor,
+    LarsCV,
+    Lasso,
+    LassoCV,
+    LassoLars,
+    LassoLarsCV,
+    LassoLarsIC,
+    LinearRegression,
+    LogisticRegression,
+    LogisticRegressionCV,
+    MultiTaskElasticNet,
+    MultiTaskElasticNetCV,
+    MultiTaskLasso,
+    MultiTaskLassoCV,
+    OrthogonalMatchingPursuitCV,
+    PassiveAggressiveClassifier,
+    PassiveAggressiveRegressor,
+    Perceptron,
+    PoissonRegressor,
+    RANSACRegressor,
+    Ridge,
+    SGDClassifier,
+    SGDOneClassSVM,
+    SGDRegressor,
+    TheilSenRegressor,
+    TweedieRegressor,
+)
+from sklearn.manifold import MDS, TSNE, LocallyLinearEmbedding, SpectralEmbedding
+from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
+from sklearn.model_selection import (
+    GridSearchCV,
+    HalvingGridSearchCV,
+    HalvingRandomSearchCV,
+    RandomizedSearchCV,
+    TunedThresholdClassifierCV,
+)
+from sklearn.multioutput import ClassifierChain, RegressorChain
+from sklearn.neighbors import NeighborhoodComponentsAnalysis
+from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor
+from sklearn.pipeline import Pipeline, make_pipeline
+from sklearn.preprocessing import OneHotEncoder, StandardScaler, TargetEncoder
+from sklearn.random_projection import (
+    GaussianRandomProjection,
+    SparseRandomProjection,
+)
+from sklearn.semi_supervised import (
+    LabelPropagation,
+    LabelSpreading,
+    SelfTrainingClassifier,
+)
+from sklearn.svm import SVC, SVR, LinearSVC, LinearSVR, NuSVC, NuSVR, OneClassSVM
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+from sklearn.utils import all_estimators
+from sklearn.utils._testing import SkipTest, set_random_state
+
+CROSS_DECOMPOSITION = ["PLSCanonical", "PLSRegression", "CCA", "PLSSVD"]
+
+# The following dictionary indicates constructor arguments suitable for the test
+# suite, which uses very small datasets and is intended to run rather quickly.
+TEST_PARAMS = {
+    AdaBoostClassifier: dict(n_estimators=5),
+    AdaBoostRegressor: dict(n_estimators=5),
+    AffinityPropagation: dict(max_iter=5),
+    AgglomerativeClustering: dict(n_clusters=2),
+    ARDRegression: dict(max_iter=5),
+    BaggingClassifier: dict(n_estimators=5),
+    BaggingRegressor: dict(n_estimators=5),
+    BayesianGaussianMixture: dict(n_init=2, max_iter=5),
+    BayesianRidge: dict(max_iter=5),
+    BernoulliRBM: dict(n_iter=5, batch_size=10),
+    Birch: dict(n_clusters=2),
+    BisectingKMeans: dict(n_init=2, n_clusters=2, max_iter=5),
+    CalibratedClassifierCV: dict(cv=3),
+    CCA: dict(n_components=1, max_iter=5),
+    ClassifierChain: dict(cv=3),
+    DictionaryLearning: dict(max_iter=20, transform_algorithm="lasso_lars"),
+    # the default strategy "prior" would output constant predictions and fail
+    # for check_classifiers_predictions
+    DummyClassifier: dict(strategy="stratified"),
+    ElasticNetCV: dict(max_iter=5, cv=3),
+    ElasticNet: dict(max_iter=5),
+    ExtraTreesClassifier: dict(n_estimators=5),
+    ExtraTreesRegressor: dict(n_estimators=5),
+    FactorAnalysis: dict(max_iter=5),
+    FastICA: dict(max_iter=5),
+    FeatureAgglomeration: dict(n_clusters=2),
+    GammaRegressor: dict(max_iter=5),
+    GaussianMixture: dict(n_init=2, max_iter=5),
+    # Due to the Johnson-Lindenstrauss lemma and often very few samples, the
+    # number of components of the random matrix projection will probably be
+    # greater than the number of features.
+    # So we impose a smaller number (avoid "auto" mode)
+    GaussianRandomProjection: dict(n_components=2),
+    GradientBoostingClassifier: dict(n_estimators=5),
+    GradientBoostingRegressor: dict(n_estimators=5),
+    GraphicalLassoCV: dict(max_iter=5, cv=3),
+    GraphicalLasso: dict(max_iter=5),
+    GridSearchCV: dict(cv=3),
+    HalvingGridSearchCV: dict(cv=3),
+    HalvingRandomSearchCV: dict(cv=3),
+    HDBSCAN: dict(min_samples=1),
+    # The default min_samples_leaf (20) isn't appropriate for the small
+    # datasets (only very shallow trees are built) that the checks use.
+    HistGradientBoostingClassifier: dict(max_iter=5, min_samples_leaf=5),
+    HistGradientBoostingRegressor: dict(max_iter=5, min_samples_leaf=5),
+    HuberRegressor: dict(max_iter=5),
+    IncrementalPCA: dict(batch_size=10),
+    IsolationForest: dict(n_estimators=5),
+    KMeans: dict(n_init=2, n_clusters=2, max_iter=5),
+    LabelPropagation: dict(max_iter=5),
+    LabelSpreading: dict(max_iter=5),
+    LarsCV: dict(max_iter=5, cv=3),
+    LassoCV: dict(max_iter=5, cv=3),
+    Lasso: dict(max_iter=5),
+    LassoLarsCV: dict(max_iter=5, cv=3),
+    LassoLars: dict(max_iter=5),
+    # Noise variance estimation does not work when `n_samples < n_features`.
+    # We need to provide the noise variance explicitly.
+    LassoLarsIC: dict(max_iter=5, noise_variance=1.0),
+    LatentDirichletAllocation: dict(max_iter=5, batch_size=10),
+    LinearSVR: dict(max_iter=20),
+    LinearSVC: dict(max_iter=20),
+    LocallyLinearEmbedding: dict(max_iter=5),
+    LogisticRegressionCV: dict(max_iter=5, cv=3),
+    LogisticRegression: dict(max_iter=5),
+    MDS: dict(n_init=2, max_iter=5),
+    # In the case of check_fit2d_1sample, bandwidth is set to None and
+    # is thus estimated. De facto it is 0.0 as a single sample is provided,
+    # and this makes the test fail. Hence we give it a placeholder value.
+    MeanShift: dict(max_iter=5, bandwidth=1.0),
+    MiniBatchDictionaryLearning: dict(batch_size=10, max_iter=5),
+    MiniBatchKMeans: dict(n_init=2, n_clusters=2, max_iter=5, batch_size=10),
+    MiniBatchNMF: dict(batch_size=10, max_iter=20, fresh_restarts=True),
+    MiniBatchSparsePCA: dict(max_iter=5, batch_size=10),
+    MLPClassifier: dict(max_iter=100),
+    MLPRegressor: dict(max_iter=100),
+    MultiTaskElasticNetCV: dict(max_iter=5, cv=3),
+    MultiTaskElasticNet: dict(max_iter=5),
+    MultiTaskLassoCV: dict(max_iter=5, cv=3),
+    MultiTaskLasso: dict(max_iter=5),
+    NeighborhoodComponentsAnalysis: dict(max_iter=5),
+    NMF: dict(max_iter=500),
+    NuSVC: dict(max_iter=-1),
+    NuSVR: dict(max_iter=-1),
+    OneClassSVM: dict(max_iter=-1),
+    OneHotEncoder: dict(handle_unknown="ignore"),
+    OrthogonalMatchingPursuitCV: dict(cv=3),
+    PassiveAggressiveClassifier: dict(max_iter=5),
+    PassiveAggressiveRegressor: dict(max_iter=5),
+    Perceptron: dict(max_iter=5),
+    PLSCanonical: dict(n_components=1, max_iter=5),
+    PLSRegression: dict(n_components=1, max_iter=5),
+    PLSSVD: dict(n_components=1),
+    PoissonRegressor: dict(max_iter=5),
+    RandomForestClassifier: dict(n_estimators=5),
+    RandomForestRegressor: dict(n_estimators=5),
+    RandomizedSearchCV: dict(n_iter=5, cv=3),
+    RandomTreesEmbedding: dict(n_estimators=5),
+    RANSACRegressor: dict(max_trials=10),
+    RegressorChain: dict(cv=3),
+    RFECV: dict(cv=3),
+    # be tolerant of noisy datasets (this is about robustness, not speed)
+    SelectFdr: dict(alpha=0.5),
+    # SelectKBest has a default of k=10,
+    # which is more features than we have in most cases.
+    SelectKBest: dict(k=1),
+    SelfTrainingClassifier: dict(max_iter=5),
+    SequentialFeatureSelector: dict(cv=3),
+    SGDClassifier: dict(max_iter=5),
+    SGDOneClassSVM: dict(max_iter=5),
+    SGDRegressor: dict(max_iter=5),
+    SparsePCA: dict(max_iter=5),
+    # Due to the Johnson-Lindenstrauss lemma and often very few samples, the
+    # number of components of the random matrix projection will probably be
+    # greater than the number of features.
+    # So we impose a smaller number (avoid "auto" mode)
+    SparseRandomProjection: dict(n_components=2),
+    SpectralBiclustering: dict(n_init=2, n_best=1, n_clusters=2),
+    SpectralClustering: dict(n_init=2, n_clusters=2),
+    SpectralCoclustering: dict(n_init=2, n_clusters=2),
+    # The default "auto" parameter can lead to a different ordering of the
+    # eigenvalues on Windows: #24105
+    SpectralEmbedding: dict(eigen_tol=1e-5),
+    StackingClassifier: dict(cv=3),
+    StackingRegressor: dict(cv=3),
+    SVC: dict(max_iter=-1),
+    SVR: dict(max_iter=-1),
+    TargetEncoder: dict(cv=3),
+    TheilSenRegressor: dict(max_iter=5, max_subpopulation=100),
+    # TruncatedSVD doesn't run with n_components = n_features
+    TruncatedSVD: dict(n_iter=5, n_components=1),
+    TSNE: dict(perplexity=2),
+    TunedThresholdClassifierCV: dict(cv=3),
+    TweedieRegressor: dict(max_iter=5),
+}
+
+
+def _set_checking_parameters(estimator):
+    """Set the parameters of an estimator instance to speed up tests and avoid
+    deprecation warnings in common tests."""
+    if type(estimator) in TEST_PARAMS:
+        test_params = TEST_PARAMS[type(estimator)]
+        estimator.set_params(**test_params)
+
+
+def _tested_estimators(type_filter=None):
+    for name, Estimator in all_estimators(type_filter=type_filter):
+        try:
+            estimator = _construct_instance(Estimator)
+        except SkipTest:
+            continue
+
+        yield estimator
+
+
+def _generate_pipeline():
+    """Generator of simple pipelines to check the compliance of the
+    :class:`~sklearn.pipeline.Pipeline` class.
+    """
+    for final_estimator in [Ridge(), LogisticRegression()]:
+        yield Pipeline(
+            steps=[
+                ("scaler", StandardScaler()),
+                ("final_estimator", final_estimator),
+            ]
+        )
+
+
+def _construct_instance(Estimator):
+    """Construct an Estimator instance if possible."""
+    required_parameters = getattr(Estimator, "_required_parameters", [])
+    if len(required_parameters):
+        if required_parameters in (["estimator"], ["base_estimator"]):
+            # `RANSACRegressor` will raise an error with any model other
+            # than `LinearRegression` if we don't fix the `min_samples`
+            # parameter. For common tests, we can enforce using
+            # `LinearRegression`, which is the default estimator in
+            # `RANSACRegressor`, instead of `Ridge`.
+            if issubclass(Estimator, RANSACRegressor):
+                estimator = Estimator(LinearRegression())
+            elif issubclass(Estimator, RegressorMixin):
+                estimator = Estimator(Ridge())
+            elif issubclass(Estimator, SelectFromModel):
+                # Increases coverage because SGDRegressor has partial_fit
+                estimator = Estimator(SGDRegressor(random_state=0))
+            else:
+                estimator = Estimator(LogisticRegression(C=1))
+        elif required_parameters in (["estimators"],):
+            # Heterogeneous ensemble classes (i.e. stacking, voting)
+            if issubclass(Estimator, RegressorMixin):
+                estimator = Estimator(
+                    estimators=[
+                        ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)),
+                        ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)),
+                    ]
+                )
+            else:
+                estimator = Estimator(
+                    estimators=[
+                        ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)),
+                        ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)),
+                    ]
+                )
+        else:
+            msg = (
+                f"Can't instantiate estimator {Estimator.__name__} with "
+                f"parameters {required_parameters}"
+            )
+            # raise an additional warning to be shown by pytest
+            warnings.warn(msg, SkipTestWarning)
+            raise SkipTest(msg)
+    else:
+        estimator = Estimator()
+    return estimator
+
+
+def _get_check_estimator_ids(obj):
+    """Create pytest ids for checks.
+
+    When `obj` is an estimator, this returns the pprint version of the
+    estimator (with `print_changed_only=True`). When `obj` is a function, the
+    name of the function is returned with its keyword arguments.
+
+    `_get_check_estimator_ids` is designed to be used as the `id` in
+    `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)`
+    is yielding estimators and checks.
+
+    Parameters
+    ----------
+    obj : estimator or function
+        Items generated by `check_estimator`.
+
+    Returns
+    -------
+    id : str or None
+
+    See Also
+    --------
+    check_estimator
+    """
+    if isfunction(obj):
+        return obj.__name__
+    if isinstance(obj, partial):
+        if not obj.keywords:
+            return obj.func.__name__
+        kwstring = ",".join(["{}={}".format(k, v) for k, v in obj.keywords.items()])
+        return "{}({})".format(obj.func.__name__, kwstring)
+    if hasattr(obj, "get_params"):
+        with config_context(print_changed_only=True):
+            return re.sub(r"\s", "", str(obj))
+
+
+def _generate_column_transformer_instances():
+    """Generate a `ColumnTransformer` instance to check its compliance with
+    scikit-learn."""
+    yield ColumnTransformer(
+        transformers=[
+            ("trans1", StandardScaler(), [0, 1]),
+        ]
+    )
+
+
+def _generate_search_cv_instances():
+    """Generator of `SearchCV` instances to check their compliance with scikit-learn."""
+    for SearchCV, (Estimator, param_grid) in product(
+        [
+            GridSearchCV,
+            HalvingGridSearchCV,
+            RandomizedSearchCV,
+            HalvingRandomSearchCV,
+        ],
+        [
+            (Ridge, {"alpha": [0.1, 1.0]}),
+            (LogisticRegression, {"C": [0.1, 1.0]}),
+        ],
+    ):
+        init_params = signature(SearchCV).parameters
+        extra_params = (
+            {"min_resources": "smallest"} if "min_resources" in init_params else {}
+        )
+        search_cv = SearchCV(
+            Estimator(), param_grid, cv=2, error_score="raise", **extra_params
+        )
+        set_random_state(search_cv)
+        yield search_cv
+
+    for SearchCV, (Estimator, param_grid) in product(
+        [
+            GridSearchCV,
+            HalvingGridSearchCV,
+            RandomizedSearchCV,
+            HalvingRandomSearchCV,
+        ],
+        [
+            (Ridge, {"ridge__alpha": [0.1, 1.0]}),
+            (LogisticRegression, {"logisticregression__C": [0.1, 1.0]}),
+        ],
+    ):
+        init_params = signature(SearchCV).parameters
+        extra_params = (
+            {"min_resources": "smallest"} if "min_resources" in init_params else {}
+        )
+        search_cv = SearchCV(
+            make_pipeline(PCA(), Estimator()), param_grid, cv=2, **extra_params
+        ).set_params(error_score="raise")
+        set_random_state(search_cv)
+        yield search_cv
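With the new module in place, the per-estimator speed-ups become a single type-keyed dictionary lookup instead of the name-based if/elif ladder removed from `estimator_checks.py` below. A small usage sketch:

```python
# Sketch: _set_checking_parameters applies TEST_PARAMS via an exact-type lookup.
from sklearn.linear_model import LogisticRegression
from sklearn.utils._test_common.instance_generator import (
    TEST_PARAMS,
    _set_checking_parameters,
)

clf = LogisticRegression()
_set_checking_parameters(clf)  # applies TEST_PARAMS[LogisticRegression]
print(clf.max_iter)  # 5, per the dict(max_iter=5) entry above
```

Note that the lookup keys on `type(estimator)`, so subclasses do not inherit the entry of their base class — a behavior change from the old attribute- and name-sniffing code that appears intentional, favoring explicit per-class entries.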
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index c334a73ab19f1..dae5a2f3170a0 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -9,7 +9,7 @@
 from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial, wraps
-from inspect import isfunction, signature
+from inspect import signature
 from numbers import Integral, Real
 
 import joblib
@@ -20,7 +20,6 @@
 from .. import config_context
 from ..base import (
     ClusterMixin,
-    RegressorMixin,
     clone,
     is_classifier,
     is_outlier_detector,
@@ -34,14 +33,6 @@
     make_regression,
 )
 from ..exceptions import DataConversionWarning, NotFittedError, SkipTestWarning
-from ..feature_selection import SelectFromModel, SelectKBest
-from ..linear_model import (
-    LinearRegression,
-    LogisticRegression,
-    RANSACRegressor,
-    Ridge,
-    SGDRegressor,
-)
 from ..linear_model._base import LinearClassifierMixin
 from ..metrics import accuracy_score, adjusted_rand_score, f1_score
 from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel
@@ -49,8 +40,6 @@
 from ..model_selection._validation import _safe_split
 from ..pipeline import make_pipeline
 from ..preprocessing import StandardScaler, scale
-from ..random_projection import BaseRandomProjection
-from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
 from ..utils._array_api import (
     _atol_for_type,
     _convert_to_numpy,
@@ -70,6 +59,11 @@
     _DEFAULT_TAGS,
     _safe_tags,
 )
+from ._test_common.instance_generator import (
+    CROSS_DECOMPOSITION,
+    _construct_instance,
+    _get_check_estimator_ids,
+)
 from ._testing import (
     SkipTest,
     _array_api_for_tests,
@@ -88,7 +82,6 @@
 from .validation import _num_samples, check_is_fitted, has_fit_parameter
 
 REGRESSION_DATASET = None
-CROSS_DECOMPOSITION = ["PLSCanonical", "PLSRegression", "CCA", "PLSSVD"]
 
 
 def _yield_checks(estimator):
@@ -390,89 +383,6 @@ def _yield_all_checks(estimator):
     yield check_fit_non_negative
 
 
-def _get_check_estimator_ids(obj):
-    """Create pytest ids for checks.
-
-    When `obj` is an estimator, this returns the pprint version of the
-    estimator (with `print_changed_only=True`). When `obj` is a function, the
-    name of the function is returned with its keyword arguments.
-
-    `_get_check_estimator_ids` is designed to be used as the `id` in
-    `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)`
-    is yielding estimators and checks.
-
-    Parameters
-    ----------
-    obj : estimator or function
-        Items generated by `check_estimator`.
-
-    Returns
-    -------
-    id : str or None
-
-    See Also
-    --------
-    check_estimator
-    """
-    if isfunction(obj):
-        return obj.__name__
-    if isinstance(obj, partial):
-        if not obj.keywords:
-            return obj.func.__name__
-        kwstring = ",".join(["{}={}".format(k, v) for k, v in obj.keywords.items()])
-        return "{}({})".format(obj.func.__name__, kwstring)
-    if hasattr(obj, "get_params"):
-        with config_context(print_changed_only=True):
-            return re.sub(r"\s", "", str(obj))
-
-
-def _construct_instance(Estimator):
-    """Construct Estimator instance if possible."""
-    required_parameters = getattr(Estimator, "_required_parameters", [])
-    if len(required_parameters):
-        if required_parameters in (["estimator"], ["base_estimator"]):
-            # `RANSACRegressor` will raise an error with any model other
-            # than `LinearRegression` if we don't fix `min_samples` parameter.
-            # For common test, we can enforce using `LinearRegression` that
-            # is the default estimator in `RANSACRegressor` instead of `Ridge`.
-            if issubclass(Estimator, RANSACRegressor):
-                estimator = Estimator(LinearRegression())
-            elif issubclass(Estimator, RegressorMixin):
-                estimator = Estimator(Ridge())
-            elif issubclass(Estimator, SelectFromModel):
-                # Increases coverage because SGDRegressor has partial_fit
-                estimator = Estimator(SGDRegressor(random_state=0))
-            else:
-                estimator = Estimator(LogisticRegression(C=1))
-        elif required_parameters in (["estimators"],):
-            # Heterogeneous ensemble classes (i.e. stacking, voting)
-            if issubclass(Estimator, RegressorMixin):
-                estimator = Estimator(
-                    estimators=[
-                        ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)),
-                        ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)),
-                    ]
-                )
-            else:
-                estimator = Estimator(
-                    estimators=[
-                        ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)),
-                        ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)),
-                    ]
-                )
-        else:
-            msg = (
-                f"Can't instantiate estimator {Estimator.__name__} "
-                f"parameters {required_parameters}"
-            )
-            # raise additional warning to be shown by pytest
-            warnings.warn(msg, SkipTestWarning)
-            raise SkipTest(msg)
-    else:
-        estimator = Estimator()
-    return estimator
-
-
 def _maybe_mark_xfail(estimator, check, pytest):
     # Mark (estimator, check) pairs as XFAIL if needed (see conditions in
     # _should_be_skipped_or_marked())
@@ -682,124 +592,6 @@ def _regression_dataset():
     return REGRESSION_DATASET
 
 
-def _set_checking_parameters(estimator):
-    # set parameters to speed up some estimators and
-    # avoid deprecated behaviour
-    params = estimator.get_params()
-    name = estimator.__class__.__name__
-    if name == "TSNE":
-        estimator.set_params(perplexity=2)
-    if "n_iter" in params and name != "TSNE":
-        estimator.set_params(n_iter=5)
-    if "max_iter" in params:
-        if estimator.max_iter is not None:
-            estimator.set_params(max_iter=min(5, estimator.max_iter))
-        # LinearSVR, LinearSVC
-        if name in ["LinearSVR", "LinearSVC"]:
-            estimator.set_params(max_iter=20)
-        # NMF
-        if name == "NMF":
-            estimator.set_params(max_iter=500)
-        # DictionaryLearning
-        if name == "DictionaryLearning":
-            estimator.set_params(max_iter=20, transform_algorithm="lasso_lars")
-        # MiniBatchNMF
-        if estimator.__class__.__name__ == "MiniBatchNMF":
-            estimator.set_params(max_iter=20, fresh_restarts=True)
-        # MLP
-        if name in ["MLPClassifier", "MLPRegressor"]:
-            estimator.set_params(max_iter=100)
-        # MiniBatchDictionaryLearning
-        if name == "MiniBatchDictionaryLearning":
-            estimator.set_params(max_iter=5)
-
-    if "n_resampling" in params:
-        # randomized lasso
-        estimator.set_params(n_resampling=5)
-    if "n_estimators" in params:
-        estimator.set_params(n_estimators=min(5, estimator.n_estimators))
-    if "max_trials" in params:
-        # RANSAC
-        estimator.set_params(max_trials=10)
-    if "n_init" in params:
-        # K-Means
-        estimator.set_params(n_init=2)
-    if "batch_size" in params and not name.startswith("MLP"):
-        estimator.set_params(batch_size=10)
-
-    if name == "MeanShift":
-        # In the case of check_fit2d_1sample, bandwidth is set to None and
-        # is thus estimated. De facto it is 0.0 as a single sample is provided
-        # and this makes the test fails. Hence we give it a placeholder value.
-        estimator.set_params(bandwidth=1.0)
-
-    if name == "TruncatedSVD":
-        # TruncatedSVD doesn't run with n_components = n_features
-        # This is ugly :-/
-        estimator.n_components = 1
-
-    if name == "LassoLarsIC":
-        # Noise variance estimation does not work when `n_samples < n_features`.
-        # We need to provide the noise variance explicitly.
-        estimator.set_params(noise_variance=1.0)
-
-    if hasattr(estimator, "n_clusters"):
-        estimator.n_clusters = min(estimator.n_clusters, 2)
-
-    if hasattr(estimator, "n_best"):
-        estimator.n_best = 1
-
-    if name == "SelectFdr":
-        # be tolerant of noisy datasets (not actually speed)
-        estimator.set_params(alpha=0.5)
-
-    if name == "TheilSenRegressor":
-        estimator.max_subpopulation = 100
-
-    if isinstance(estimator, BaseRandomProjection):
-        # Due to the jl lemma and often very few samples, the number
-        # of components of the random matrix projection will be probably
-        # greater than the number of features.
-        # So we impose a smaller number (avoid "auto" mode)
-        estimator.set_params(n_components=2)
-
-    if isinstance(estimator, SelectKBest):
-        # SelectKBest has a default of k=10
-        # which is more feature than we have in most case.
-        estimator.set_params(k=1)
-
-    if name in ("HistGradientBoostingClassifier", "HistGradientBoostingRegressor"):
-        # The default min_samples_leaf (20) isn't appropriate for small
-        # datasets (only very shallow trees are built) that the checks use.
-        estimator.set_params(min_samples_leaf=5)
-
-    if name == "DummyClassifier":
-        # the default strategy prior would output constant predictions and fail
-        # for check_classifiers_predictions
-        estimator.set_params(strategy="stratified")
-
-    # Speed-up by reducing the number of CV or splits for CV estimators
-    loo_cv = ["RidgeCV", "RidgeClassifierCV"]
-    if name not in loo_cv and hasattr(estimator, "cv"):
-        estimator.set_params(cv=3)
-    if hasattr(estimator, "n_splits"):
-        estimator.set_params(n_splits=3)
-
-    if name == "OneHotEncoder":
-        estimator.set_params(handle_unknown="ignore")
-
-    if name in CROSS_DECOMPOSITION:
-        estimator.set_params(n_components=1)
-
-    # Default "auto" parameter can lead to different ordering of eigenvalues on
-    # windows: #24105
-    if name == "SpectralEmbedding":
-        estimator.set_params(eigen_tol=1e-5)
-
-    if name == "HDBSCAN":
-        estimator.set_params(min_samples=1)
-
-
 class _NotAnArray:
     """An object that is convertible to an array.
 
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 8d03b5c851fba..f01bb99763cd4 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -30,6 +30,7 @@
 from sklearn.svm import SVC, NuSVC
 from sklearn.utils import _array_api, all_estimators, deprecated
 from sklearn.utils._param_validation import Interval, StrOptions
+from sklearn.utils._test_common.instance_generator import _set_checking_parameters
 from sklearn.utils._testing import (
     MinimalClassifier,
     MinimalRegressor,
@@ -40,7 +41,6 @@
 )
 from sklearn.utils.estimator_checks import (
     _NotAnArray,
-    _set_checking_parameters,
     _yield_all_checks,
     check_array_api_input,
     check_class_weight_balanced_linear_classifier,
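For completeness, a sketch of the id strings `_get_check_estimator_ids` produces for the three kinds of objects it handles, per the function body relocated above (expected outputs shown as comments):

```python
# Sketch: ids produced for a plain function, a functools.partial, and an estimator.
from functools import partial

from sklearn.linear_model import LogisticRegression
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids


def check_fit(estimator):
    pass


print(_get_check_estimator_ids(check_fit))  # "check_fit"
print(_get_check_estimator_ids(partial(check_fit, strict=True)))  # "check_fit(strict=True)"
print(_get_check_estimator_ids(LogisticRegression(C=2.0)))  # "LogisticRegression(C=2.0)"
```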