From d28c3cb7c3f465a668a655a36d477a31f360ae35 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Wed, 21 Aug 2024 15:50:51 +0200
Subject: [PATCH 01/11] TST allow categorisation of tests into API and legacy

---
 sklearn/utils/estimator_checks.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 42edfe0d4d3c4..01b1276edcdaa 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -90,13 +90,17 @@
 CROSS_DECOMPOSITION = ["PLSCanonical", "PLSRegression", "CCA", "PLSSVD"]


+def _yield_api_checks(estimator):
+    yield check_no_attributes_set_in_init
+    yield check_fit_score_takes_y
+    yield check_estimators_overwrite_params
+
+
 def _yield_checks(estimator):
     name = estimator.__class__.__name__
     tags = _safe_tags(estimator)

-    yield check_no_attributes_set_in_init
     yield check_estimators_dtypes
-    yield check_fit_score_takes_y
     if has_fit_parameter(estimator, "sample_weight"):
         yield check_sample_weights_pandas_series
         yield check_sample_weights_not_an_array
@@ -129,7 +133,6 @@ def _yield_checks(estimator):
         # Check that pairwise estimator throws error on non-square input
         yield check_nonsquare_error

-    yield check_estimators_overwrite_params
     if hasattr(estimator, "sparsify"):
         yield check_sparsify_coefficients

@@ -323,7 +326,7 @@ def _yield_array_api_checks(estimator):
     )


-def _yield_all_checks(estimator):
+def _yield_all_checks(estimator, legacy: bool):
     name = estimator.__class__.__name__
     tags = _safe_tags(estimator)
     if "2darray" not in tags["X_types"]:
@@ -341,6 +344,12 @@ def _yield_all_checks(estimator):
         )
         return

+    for check in _yield_api_checks(estimator):
+        yield check
+
+    if not legacy:
+        return
+
     for check in _yield_checks(estimator):
         yield check
     if is_classifier(estimator):
@@ -513,9 +522,14 @@ def _should_be_skipped_or_marked(estimator, check):
     return False, "placeholder reason that will never be used"


-def parametrize_with_checks(estimators):
+def parametrize_with_checks(estimators, legacy=True):
     """Pytest specific decorator for parametrizing estimator checks.

+    Checks are categorised into the following groups:
+
+    - API checks: a set of checks to ensure API compatibility with scikit-learn
+    - legacy: a set of checks which will gradually be grouped into other categories
+
     The `id` of each check is set to be a pprint version of the estimator
     and the name of the check with its keyword arguments.
     This allows to use `pytest -k` to specify which tests to run::
@@ -533,6 +547,11 @@ def parametrize_with_checks(estimators):

         .. versionadded:: 0.24

+    legacy : bool (default=True)
+        Whether to include legacy checks.
+
+        .. 
versionadded:: 1.6 + Returns ------- decorator : `pytest.mark.parametrize` @@ -566,7 +585,7 @@ def parametrize_with_checks(estimators): def checks_generator(): for estimator in estimators: name = type(estimator).__name__ - for check in _yield_all_checks(estimator): + for check in _yield_all_checks(estimator, legacy=legacy): check = partial(check, name) yield _maybe_mark_xfail(estimator, check, pytest) From 13a8e27c422a2d383585f7b098216984e797234e Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 22 Aug 2024 08:09:41 +0200 Subject: [PATCH 02/11] TST refactor instance generation and parameter setting --- doc/sphinxext/allow_nan_estimators.py | 2 +- sklearn/decomposition/tests/test_pca.py | 2 +- sklearn/linear_model/tests/test_ridge.py | 2 +- sklearn/preprocessing/tests/test_data.py | 2 +- sklearn/tests/test_common.py | 104 +---- sklearn/tests/test_docstring_parameters.py | 2 +- sklearn/utils/_test_common/__init__.py | 2 + .../utils/_test_common/instance_generator.py | 441 ++++++++++++++++++ sklearn/utils/estimator_checks.py | 220 +-------- sklearn/utils/tests/test_estimator_checks.py | 2 +- 10 files changed, 466 insertions(+), 313 deletions(-) create mode 100644 sklearn/utils/_test_common/__init__.py create mode 100644 sklearn/utils/_test_common/instance_generator.py diff --git a/doc/sphinxext/allow_nan_estimators.py b/doc/sphinxext/allow_nan_estimators.py index 89d7077bce2b5..00a6ddc0048e9 100755 --- a/doc/sphinxext/allow_nan_estimators.py +++ b/doc/sphinxext/allow_nan_estimators.py @@ -4,8 +4,8 @@ from docutils.parsers.rst import Directive from sklearn.utils import all_estimators +from sklearn.utils._test_common.instance_generator import _construct_instance from sklearn.utils._testing import SkipTest -from sklearn.utils.estimator_checks import _construct_instance class AllowNanEstimators(Directive): diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index bd7f60061abdc..52f769bfb9001 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -17,9 +17,9 @@ yield_namespace_device_dtype_combinations, ) from sklearn.utils._array_api import device as array_device +from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids from sklearn.utils._testing import _array_api_for_tests, assert_allclose from sklearn.utils.estimator_checks import ( - _get_check_estimator_ids, check_array_api_input_and_values, ) from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py index 9db58dd499269..c727d268e0ebc 100644 --- a/sklearn/linear_model/tests/test_ridge.py +++ b/sklearn/linear_model/tests/test_ridge.py @@ -48,6 +48,7 @@ yield_namespace_device_dtype_combinations, yield_namespaces, ) +from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids from sklearn.utils._testing import ( assert_allclose, assert_almost_equal, @@ -57,7 +58,6 @@ ) from sklearn.utils.estimator_checks import ( _array_api_for_tests, - _get_check_estimator_ids, check_array_api_input_and_values, ) from sklearn.utils.fixes import ( diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index 5d254e491b400..049b188cf66a7 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -40,6 +40,7 @@ from sklearn.utils._array_api import ( yield_namespace_device_dtype_combinations, ) +from 
sklearn.utils._test_common.instance_generator import _get_check_estimator_ids from sklearn.utils._testing import ( _convert_container, assert_allclose, @@ -51,7 +52,6 @@ skip_if_32bit, ) from sklearn.utils.estimator_checks import ( - _get_check_estimator_ids, check_array_api_input_and_values, ) from sklearn.utils.fixes import ( diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 3a61503530f23..467c7db9a3d21 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -11,7 +11,7 @@ import warnings from functools import partial from inspect import isgenerator, signature -from itertools import chain, product +from itertools import chain import numpy as np import pytest @@ -26,25 +26,13 @@ MeanShift, SpectralClustering, ) -from sklearn.compose import ColumnTransformer from sklearn.datasets import make_blobs -from sklearn.decomposition import PCA from sklearn.exceptions import ConvergenceWarning, FitFailedWarning # make it possible to discover experimental estimators when calling `all_estimators` -from sklearn.experimental import ( - enable_halving_search_cv, # noqa - enable_iterative_imputer, # noqa -) -from sklearn.linear_model import LogisticRegression, Ridge +from sklearn.linear_model import LogisticRegression from sklearn.linear_model._base import LinearClassifierMixin from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding -from sklearn.model_selection import ( - GridSearchCV, - HalvingGridSearchCV, - HalvingRandomSearchCV, - RandomizedSearchCV, -) from sklearn.neighbors import ( KNeighborsClassifier, KNeighborsRegressor, @@ -52,7 +40,7 @@ RadiusNeighborsClassifier, RadiusNeighborsRegressor, ) -from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.pipeline import make_pipeline from sklearn.preprocessing import ( FunctionTransformer, MinMaxScaler, @@ -62,15 +50,19 @@ from sklearn.semi_supervised import LabelPropagation, LabelSpreading from sklearn.utils import all_estimators from sklearn.utils._tags import _DEFAULT_TAGS, _safe_tags +from sklearn.utils._test_common.instance_generator import ( + _generate_column_transformer_instances, + _generate_pipeline, + _generate_search_cv_instances, + _get_check_estimator_ids, + _set_checking_parameters, + _tested_estimators, +) from sklearn.utils._testing import ( SkipTest, ignore_warnings, - set_random_state, ) from sklearn.utils.estimator_checks import ( - _construct_instance, - _get_check_estimator_ids, - _set_checking_parameters, check_class_weight_balanced_linear_classifier, check_dataframe_column_names_consistency, check_estimator, @@ -139,26 +131,6 @@ def test_get_check_estimator_ids(val, expected): assert _get_check_estimator_ids(val) == expected -def _tested_estimators(type_filter=None): - for name, Estimator in all_estimators(type_filter=type_filter): - try: - estimator = _construct_instance(Estimator) - except SkipTest: - continue - - yield estimator - - -def _generate_pipeline(): - for final_estimator in [Ridge(), LogisticRegression()]: - yield Pipeline( - steps=[ - ("scaler", StandardScaler()), - ("final_estimator", final_estimator), - ] - ) - - @parametrize_with_checks(list(chain(_tested_estimators(), _generate_pipeline()))) def test_estimators(estimator, check, request): # Common tests for estimator instances @@ -282,60 +254,6 @@ def test_class_support_removed(): parametrize_with_checks([LogisticRegression]) -def _generate_column_transformer_instances(): - yield ColumnTransformer( - transformers=[ - ("trans1", StandardScaler(), [0, 1]), - ] - ) - - -def 
_generate_search_cv_instances(): - for SearchCV, (Estimator, param_grid) in product( - [ - GridSearchCV, - HalvingGridSearchCV, - RandomizedSearchCV, - HalvingGridSearchCV, - ], - [ - (Ridge, {"alpha": [0.1, 1.0]}), - (LogisticRegression, {"C": [0.1, 1.0]}), - ], - ): - init_params = signature(SearchCV).parameters - extra_params = ( - {"min_resources": "smallest"} if "min_resources" in init_params else {} - ) - search_cv = SearchCV( - Estimator(), param_grid, cv=2, error_score="raise", **extra_params - ) - set_random_state(search_cv) - yield search_cv - - for SearchCV, (Estimator, param_grid) in product( - [ - GridSearchCV, - HalvingGridSearchCV, - RandomizedSearchCV, - HalvingRandomSearchCV, - ], - [ - (Ridge, {"ridge__alpha": [0.1, 1.0]}), - (LogisticRegression, {"logisticregression__C": [0.1, 1.0]}), - ], - ): - init_params = signature(SearchCV).parameters - extra_params = ( - {"min_resources": "smallest"} if "min_resources" in init_params else {} - ) - search_cv = SearchCV( - make_pipeline(PCA(), Estimator()), param_grid, cv=2, **extra_params - ).set_params(error_score="raise") - set_random_state(search_cv) - yield search_cv - - @parametrize_with_checks(list(_generate_search_cv_instances())) def test_search_cv(estimator, check, request): # Common tests for SearchCV instances diff --git a/sklearn/tests/test_docstring_parameters.py b/sklearn/tests/test_docstring_parameters.py index 3af463b783bc3..687b85ed00187 100644 --- a/sklearn/tests/test_docstring_parameters.py +++ b/sklearn/tests/test_docstring_parameters.py @@ -22,6 +22,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.preprocessing import FunctionTransformer from sklearn.utils import all_estimators +from sklearn.utils._test_common.instance_generator import _construct_instance from sklearn.utils._testing import ( _get_func_name, check_docstring_parameters, @@ -29,7 +30,6 @@ ) from sklearn.utils.deprecation import _is_deprecated from sklearn.utils.estimator_checks import ( - _construct_instance, _enforce_estimator_tags_X, _enforce_estimator_tags_y, ) diff --git a/sklearn/utils/_test_common/__init__.py b/sklearn/utils/_test_common/__init__.py new file mode 100644 index 0000000000000..67dd18fb94b59 --- /dev/null +++ b/sklearn/utils/_test_common/__init__.py @@ -0,0 +1,2 @@ +# Authors: The scikit-learn developers +# SPDX-License-Identifier: BSD-3-Clause diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py new file mode 100644 index 0000000000000..3e215111adcda --- /dev/null +++ b/sklearn/utils/_test_common/instance_generator.py @@ -0,0 +1,441 @@ +# Authors: The scikit-learn developers +# SPDX-License-Identifier: BSD-3-Clause + + +import re +import warnings +from functools import partial +from inspect import isfunction, signature +from itertools import product + +from sklearn import config_context +from sklearn.base import RegressorMixin +from sklearn.calibration import CalibratedClassifierCV +from sklearn.cluster import ( + HDBSCAN, + AffinityPropagation, + AgglomerativeClustering, + Birch, + BisectingKMeans, + FeatureAgglomeration, + KMeans, + MeanShift, + MiniBatchKMeans, + SpectralBiclustering, + SpectralClustering, + SpectralCoclustering, +) +from sklearn.compose import ColumnTransformer +from sklearn.covariance import GraphicalLasso, GraphicalLassoCV +from sklearn.cross_decomposition import CCA, PLSSVD, PLSCanonical, PLSRegression +from sklearn.decomposition import ( + NMF, + PCA, + DictionaryLearning, + FactorAnalysis, + FastICA, + IncrementalPCA, + 
LatentDirichletAllocation, + MiniBatchDictionaryLearning, + MiniBatchNMF, + MiniBatchSparsePCA, + SparsePCA, + TruncatedSVD, +) +from sklearn.dummy import DummyClassifier +from sklearn.ensemble import ( + AdaBoostClassifier, + AdaBoostRegressor, + BaggingClassifier, + BaggingRegressor, + ExtraTreesClassifier, + ExtraTreesRegressor, + GradientBoostingClassifier, + GradientBoostingRegressor, + HistGradientBoostingClassifier, + HistGradientBoostingRegressor, + IsolationForest, + RandomForestClassifier, + RandomForestRegressor, + RandomTreesEmbedding, + StackingClassifier, + StackingRegressor, +) +from sklearn.exceptions import SkipTestWarning +from sklearn.experimental import enable_halving_search_cv # noqa +from sklearn.feature_selection import ( + RFECV, + SelectFdr, + SelectFromModel, + SelectKBest, + SequentialFeatureSelector, +) +from sklearn.linear_model import ( + ARDRegression, + BayesianRidge, + ElasticNet, + ElasticNetCV, + GammaRegressor, + HuberRegressor, + LarsCV, + Lasso, + LassoCV, + LassoLars, + LassoLarsCV, + LassoLarsIC, + LinearRegression, + LogisticRegression, + LogisticRegressionCV, + MultiTaskElasticNet, + MultiTaskElasticNetCV, + MultiTaskLasso, + MultiTaskLassoCV, + OrthogonalMatchingPursuitCV, + PassiveAggressiveClassifier, + PassiveAggressiveRegressor, + Perceptron, + PoissonRegressor, + RANSACRegressor, + Ridge, + SGDClassifier, + SGDOneClassSVM, + SGDRegressor, + TheilSenRegressor, + TweedieRegressor, +) +from sklearn.manifold import MDS, TSNE, LocallyLinearEmbedding, SpectralEmbedding +from sklearn.mixture import BayesianGaussianMixture, GaussianMixture +from sklearn.model_selection import ( + GridSearchCV, + HalvingGridSearchCV, + HalvingRandomSearchCV, + RandomizedSearchCV, + TunedThresholdClassifierCV, +) +from sklearn.multioutput import ClassifierChain, RegressorChain +from sklearn.neighbors import NeighborhoodComponentsAnalysis +from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor +from sklearn.pipeline import Pipeline, make_pipeline +from sklearn.preprocessing import OneHotEncoder, StandardScaler, TargetEncoder +from sklearn.random_projection import ( + GaussianRandomProjection, + SparseRandomProjection, +) +from sklearn.semi_supervised import ( + LabelPropagation, + LabelSpreading, + SelfTrainingClassifier, +) +from sklearn.svm import SVC, SVR, LinearSVC, LinearSVR, NuSVC, NuSVR, OneClassSVM +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor +from sklearn.utils import all_estimators +from sklearn.utils._testing import SkipTest, set_random_state + +CROSS_DECOMPOSITION = ["PLSCanonical", "PLSRegression", "CCA", "PLSSVD"] + +# The following dictionary is to indicate constructor arguments suitable for the test +# suite, which uses very small datasets, and is intended to run rather quickly. 
+TEST_PARAMS = {
+    AdaBoostClassifier: dict(n_estimators=5),
+    AdaBoostRegressor: dict(n_estimators=5),
+    AffinityPropagation: dict(max_iter=5),
+    AgglomerativeClustering: dict(n_clusters=2),
+    ARDRegression: dict(max_iter=5),
+    BaggingClassifier: dict(n_estimators=5),
+    BaggingRegressor: dict(n_estimators=5),
+    BayesianGaussianMixture: dict(n_init=2, max_iter=5),
+    BayesianRidge: dict(max_iter=5),
+    BernoulliRBM: dict(n_iter=5, batch_size=10),
+    Birch: dict(n_clusters=2),
+    BisectingKMeans: dict(n_init=2, n_clusters=2, max_iter=5),
+    CalibratedClassifierCV: dict(cv=3),
+    CCA: dict(n_components=1, max_iter=5),
+    ClassifierChain: dict(cv=3),
+    DictionaryLearning: dict(max_iter=20, transform_algorithm="lasso_lars"),
+    # the default strategy prior would output constant predictions and fail
+    # for check_classifiers_predictions
+    DummyClassifier: dict(strategy="stratified"),
+    ElasticNetCV: dict(max_iter=5, cv=3),
+    ElasticNet: dict(max_iter=5),
+    ExtraTreesClassifier: dict(n_estimators=5),
+    ExtraTreesRegressor: dict(n_estimators=5),
+    FactorAnalysis: dict(max_iter=5),
+    FastICA: dict(max_iter=5),
+    FeatureAgglomeration: dict(n_clusters=2),
+    GammaRegressor: dict(max_iter=5),
+    GaussianMixture: dict(n_init=2, max_iter=5),
+    # Due to the JL lemma and often very few samples, the number
+    # of components of the random matrix projection will probably be
+    # greater than the number of features.
+    # So we impose a smaller number (avoid "auto" mode)
+    GaussianRandomProjection: dict(n_components=2),
+    GradientBoostingClassifier: dict(n_estimators=5),
+    GradientBoostingRegressor: dict(n_estimators=5),
+    GraphicalLassoCV: dict(max_iter=5, cv=3),
+    GraphicalLasso: dict(max_iter=5),
+    GridSearchCV: dict(cv=3),
+    HalvingGridSearchCV: dict(cv=3),
+    HalvingRandomSearchCV: dict(cv=3),
+    HDBSCAN: dict(min_samples=1),
+    # The default min_samples_leaf (20) isn't appropriate for small
+    # datasets (only very shallow trees are built) that the checks use.
+    HistGradientBoostingClassifier: dict(max_iter=5, min_samples_leaf=5),
+    HistGradientBoostingRegressor: dict(max_iter=5, min_samples_leaf=5),
+    HuberRegressor: dict(max_iter=5),
+    IncrementalPCA: dict(batch_size=10),
+    IsolationForest: dict(n_estimators=5),
+    KMeans: dict(n_init=2, n_clusters=2, max_iter=5),
+    LabelPropagation: dict(max_iter=5),
+    LabelSpreading: dict(max_iter=5),
+    LarsCV: dict(max_iter=5, cv=3),
+    LassoCV: dict(max_iter=5, cv=3),
+    Lasso: dict(max_iter=5),
+    LassoLarsCV: dict(max_iter=5, cv=3),
+    LassoLars: dict(max_iter=5),
+    # Noise variance estimation does not work when `n_samples < n_features`.
+    # We need to provide the noise variance explicitly.
+    LassoLarsIC: dict(max_iter=5, noise_variance=1.0),
+    LatentDirichletAllocation: dict(max_iter=5, batch_size=10),
+    LinearSVR: dict(max_iter=20),
+    LinearSVC: dict(max_iter=20),
+    LocallyLinearEmbedding: dict(max_iter=5),
+    LogisticRegressionCV: dict(max_iter=5, cv=3),
+    LogisticRegression: dict(max_iter=5),
+    MDS: dict(n_init=2, max_iter=5),
+    # In the case of check_fit2d_1sample, bandwidth is set to None and
+    # is thus estimated. De facto it is 0.0 as a single sample is provided
+    # and this makes the test fail. Hence we give it a placeholder value.
+    MeanShift: dict(max_iter=5, bandwidth=1.0),
+    MiniBatchDictionaryLearning: dict(batch_size=10, max_iter=5),
+    MiniBatchKMeans: dict(n_init=2, n_clusters=2, max_iter=5, batch_size=10),
+    MiniBatchNMF: dict(batch_size=10, max_iter=20, fresh_restarts=True),
+    MiniBatchSparsePCA: dict(max_iter=5, batch_size=10),
+    MLPClassifier: dict(max_iter=100),
+    MLPRegressor: dict(max_iter=100),
+    MultiTaskElasticNetCV: dict(max_iter=5, cv=3),
+    MultiTaskElasticNet: dict(max_iter=5),
+    MultiTaskLassoCV: dict(max_iter=5, cv=3),
+    MultiTaskLasso: dict(max_iter=5),
+    NeighborhoodComponentsAnalysis: dict(max_iter=5),
+    NMF: dict(max_iter=500),
+    NuSVC: dict(max_iter=-1),
+    NuSVR: dict(max_iter=-1),
+    OneClassSVM: dict(max_iter=-1),
+    OneHotEncoder: dict(handle_unknown="ignore"),
+    OrthogonalMatchingPursuitCV: dict(cv=3),
+    PassiveAggressiveClassifier: dict(max_iter=5),
+    PassiveAggressiveRegressor: dict(max_iter=5),
+    Perceptron: dict(max_iter=5),
+    PLSCanonical: dict(n_components=1, max_iter=5),
+    PLSRegression: dict(n_components=1, max_iter=5),
+    PLSSVD: dict(n_components=1),
+    PoissonRegressor: dict(max_iter=5),
+    RandomForestClassifier: dict(n_estimators=5),
+    RandomForestRegressor: dict(n_estimators=5),
+    RandomizedSearchCV: dict(n_iter=5, cv=3),
+    RandomTreesEmbedding: dict(n_estimators=5),
+    RANSACRegressor: dict(max_trials=10),
+    RegressorChain: dict(cv=3),
+    RFECV: dict(cv=3),
+    # be tolerant of noisy datasets (not actually speed)
+    SelectFdr: dict(alpha=0.5),
+    # SelectKBest has a default of k=10
+    # which is more features than we have in most cases.
+    SelectKBest: dict(k=1),
+    SelfTrainingClassifier: dict(max_iter=5),
+    SequentialFeatureSelector: dict(cv=3),
+    SGDClassifier: dict(max_iter=5),
+    SGDOneClassSVM: dict(max_iter=5),
+    SGDRegressor: dict(max_iter=5),
+    SparsePCA: dict(max_iter=5),
+    # Due to the JL lemma and often very few samples, the number
+    # of components of the random matrix projection will probably be
+    # greater than the number of features. 
+ # So we impose a smaller number (avoid "auto" mode) + SparseRandomProjection: dict(n_components=2), + SpectralBiclustering: dict(n_init=2, n_best=1, n_clusters=2), + SpectralClustering: dict(n_init=2, n_clusters=2), + SpectralCoclustering: dict(n_init=2, n_clusters=2), + # Default "auto" parameter can lead to different ordering of eigenvalues on + # windows: #24105 + SpectralEmbedding: dict(eigen_tol=1e-5), + StackingClassifier: dict(cv=3), + StackingRegressor: dict(cv=3), + SVC: dict(max_iter=-1), + SVR: dict(max_iter=-1), + TargetEncoder: dict(cv=3), + TheilSenRegressor: dict(max_iter=5, max_subpopulation=100), + # TruncatedSVD doesn't run with n_components = n_features + TruncatedSVD: dict(n_iter=5, n_components=1), + TSNE: dict(perplexity=2), + TunedThresholdClassifierCV: dict(cv=3), + TweedieRegressor: dict(max_iter=5), +} + + +def _set_checking_parameters(estimator): + # set parameters to speed up some estimators and + # avoid deprecated behaviour + params = estimator.get_params() + name = estimator.__class__.__name__ + + if type(estimator) in TEST_PARAMS: + test_params = TEST_PARAMS[type(estimator)] + estimator.set_params(**test_params) + + +def _tested_estimators(type_filter=None): + for name, Estimator in all_estimators(type_filter=type_filter): + try: + estimator = _construct_instance(Estimator) + except SkipTest: + continue + + yield estimator + + +def _generate_pipeline(): + for final_estimator in [Ridge(), LogisticRegression()]: + yield Pipeline( + steps=[ + ("scaler", StandardScaler()), + ("final_estimator", final_estimator), + ] + ) + + +def _construct_instance(Estimator): + """Construct Estimator instance if possible.""" + required_parameters = getattr(Estimator, "_required_parameters", []) + if len(required_parameters): + if required_parameters in (["estimator"], ["base_estimator"]): + # `RANSACRegressor` will raise an error with any model other + # than `LinearRegression` if we don't fix `min_samples` parameter. + # For common test, we can enforce using `LinearRegression` that + # is the default estimator in `RANSACRegressor` instead of `Ridge`. + if issubclass(Estimator, RANSACRegressor): + estimator = Estimator(LinearRegression()) + elif issubclass(Estimator, RegressorMixin): + estimator = Estimator(Ridge()) + elif issubclass(Estimator, SelectFromModel): + # Increases coverage because SGDRegressor has partial_fit + estimator = Estimator(SGDRegressor(random_state=0)) + else: + estimator = Estimator(LogisticRegression(C=1)) + elif required_parameters in (["estimators"],): + # Heterogeneous ensemble classes (i.e. stacking, voting) + if issubclass(Estimator, RegressorMixin): + estimator = Estimator( + estimators=[ + ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)), + ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)), + ] + ) + else: + estimator = Estimator( + estimators=[ + ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)), + ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)), + ] + ) + else: + msg = ( + f"Can't instantiate estimator {Estimator.__name__} " + f"parameters {required_parameters}" + ) + # raise additional warning to be shown by pytest + warnings.warn(msg, SkipTestWarning) + raise SkipTest(msg) + else: + estimator = Estimator() + return estimator + + +def _get_check_estimator_ids(obj): + """Create pytest ids for checks. + + When `obj` is an estimator, this returns the pprint version of the + estimator (with `print_changed_only=True`). 
When `obj` is a function, the + name of the function is returned with its keyword arguments. + + `_get_check_estimator_ids` is designed to be used as the `id` in + `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)` + is yielding estimators and checks. + + Parameters + ---------- + obj : estimator or function + Items generated by `check_estimator`. + + Returns + ------- + id : str or None + + See Also + -------- + check_estimator + """ + if isfunction(obj): + return obj.__name__ + if isinstance(obj, partial): + if not obj.keywords: + return obj.func.__name__ + kwstring = ",".join(["{}={}".format(k, v) for k, v in obj.keywords.items()]) + return "{}({})".format(obj.func.__name__, kwstring) + if hasattr(obj, "get_params"): + with config_context(print_changed_only=True): + return re.sub(r"\s", "", str(obj)) + + +def _generate_column_transformer_instances(): + yield ColumnTransformer( + transformers=[ + ("trans1", StandardScaler(), [0, 1]), + ] + ) + + +def _generate_search_cv_instances(): + for SearchCV, (Estimator, param_grid) in product( + [ + GridSearchCV, + HalvingGridSearchCV, + RandomizedSearchCV, + HalvingGridSearchCV, + ], + [ + (Ridge, {"alpha": [0.1, 1.0]}), + (LogisticRegression, {"C": [0.1, 1.0]}), + ], + ): + init_params = signature(SearchCV).parameters + extra_params = ( + {"min_resources": "smallest"} if "min_resources" in init_params else {} + ) + search_cv = SearchCV( + Estimator(), param_grid, cv=2, error_score="raise", **extra_params + ) + set_random_state(search_cv) + yield search_cv + + for SearchCV, (Estimator, param_grid) in product( + [ + GridSearchCV, + HalvingGridSearchCV, + RandomizedSearchCV, + HalvingRandomSearchCV, + ], + [ + (Ridge, {"ridge__alpha": [0.1, 1.0]}), + (LogisticRegression, {"logisticregression__C": [0.1, 1.0]}), + ], + ): + init_params = signature(SearchCV).parameters + extra_params = ( + {"min_resources": "smallest"} if "min_resources" in init_params else {} + ) + search_cv = SearchCV( + make_pipeline(PCA(), Estimator()), param_grid, cv=2, **extra_params + ).set_params(error_score="raise") + set_random_state(search_cv) + yield search_cv diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 42edfe0d4d3c4..745503c54a7aa 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -9,7 +9,7 @@ from contextlib import nullcontext from copy import deepcopy from functools import partial, wraps -from inspect import isfunction, signature +from inspect import signature from numbers import Integral, Real import joblib @@ -20,7 +20,6 @@ from .. 
import config_context from ..base import ( ClusterMixin, - RegressorMixin, clone, is_classifier, is_outlier_detector, @@ -34,22 +33,12 @@ make_regression, ) from ..exceptions import DataConversionWarning, NotFittedError, SkipTestWarning -from ..feature_selection import SelectFromModel, SelectKBest -from ..linear_model import ( - LinearRegression, - LogisticRegression, - RANSACRegressor, - Ridge, - SGDRegressor, -) from ..metrics import accuracy_score, adjusted_rand_score, f1_score from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel from ..model_selection import ShuffleSplit, train_test_split from ..model_selection._validation import _safe_split from ..pipeline import make_pipeline from ..preprocessing import StandardScaler, scale -from ..random_projection import BaseRandomProjection -from ..tree import DecisionTreeClassifier, DecisionTreeRegressor from ..utils._array_api import ( _atol_for_type, _convert_to_numpy, @@ -69,6 +58,11 @@ _DEFAULT_TAGS, _safe_tags, ) +from ._test_common.instance_generator import ( + CROSS_DECOMPOSITION, + _construct_instance, + _get_check_estimator_ids, +) from ._testing import ( SkipTest, _array_api_for_tests, @@ -87,7 +81,6 @@ from .validation import _num_samples, check_is_fitted, has_fit_parameter REGRESSION_DATASET = None -CROSS_DECOMPOSITION = ["PLSCanonical", "PLSRegression", "CCA", "PLSSVD"] def _yield_checks(estimator): @@ -380,89 +373,6 @@ def _yield_all_checks(estimator): yield check_fit_non_negative -def _get_check_estimator_ids(obj): - """Create pytest ids for checks. - - When `obj` is an estimator, this returns the pprint version of the - estimator (with `print_changed_only=True`). When `obj` is a function, the - name of the function is returned with its keyword arguments. - - `_get_check_estimator_ids` is designed to be used as the `id` in - `pytest.mark.parametrize` where `check_estimator(..., generate_only=True)` - is yielding estimators and checks. - - Parameters - ---------- - obj : estimator or function - Items generated by `check_estimator`. - - Returns - ------- - id : str or None - - See Also - -------- - check_estimator - """ - if isfunction(obj): - return obj.__name__ - if isinstance(obj, partial): - if not obj.keywords: - return obj.func.__name__ - kwstring = ",".join(["{}={}".format(k, v) for k, v in obj.keywords.items()]) - return "{}({})".format(obj.func.__name__, kwstring) - if hasattr(obj, "get_params"): - with config_context(print_changed_only=True): - return re.sub(r"\s", "", str(obj)) - - -def _construct_instance(Estimator): - """Construct Estimator instance if possible.""" - required_parameters = getattr(Estimator, "_required_parameters", []) - if len(required_parameters): - if required_parameters in (["estimator"], ["base_estimator"]): - # `RANSACRegressor` will raise an error with any model other - # than `LinearRegression` if we don't fix `min_samples` parameter. - # For common test, we can enforce using `LinearRegression` that - # is the default estimator in `RANSACRegressor` instead of `Ridge`. - if issubclass(Estimator, RANSACRegressor): - estimator = Estimator(LinearRegression()) - elif issubclass(Estimator, RegressorMixin): - estimator = Estimator(Ridge()) - elif issubclass(Estimator, SelectFromModel): - # Increases coverage because SGDRegressor has partial_fit - estimator = Estimator(SGDRegressor(random_state=0)) - else: - estimator = Estimator(LogisticRegression(C=1)) - elif required_parameters in (["estimators"],): - # Heterogeneous ensemble classes (i.e. 
stacking, voting) - if issubclass(Estimator, RegressorMixin): - estimator = Estimator( - estimators=[ - ("est1", DecisionTreeRegressor(max_depth=3, random_state=0)), - ("est2", DecisionTreeRegressor(max_depth=3, random_state=1)), - ] - ) - else: - estimator = Estimator( - estimators=[ - ("est1", DecisionTreeClassifier(max_depth=3, random_state=0)), - ("est2", DecisionTreeClassifier(max_depth=3, random_state=1)), - ] - ) - else: - msg = ( - f"Can't instantiate estimator {Estimator.__name__} " - f"parameters {required_parameters}" - ) - # raise additional warning to be shown by pytest - warnings.warn(msg, SkipTestWarning) - raise SkipTest(msg) - else: - estimator = Estimator() - return estimator - - def _maybe_mark_xfail(estimator, check, pytest): # Mark (estimator, check) pairs as XFAIL if needed (see conditions in # _should_be_skipped_or_marked()) @@ -672,124 +582,6 @@ def _regression_dataset(): return REGRESSION_DATASET -def _set_checking_parameters(estimator): - # set parameters to speed up some estimators and - # avoid deprecated behaviour - params = estimator.get_params() - name = estimator.__class__.__name__ - if name == "TSNE": - estimator.set_params(perplexity=2) - if "n_iter" in params and name != "TSNE": - estimator.set_params(n_iter=5) - if "max_iter" in params: - if estimator.max_iter is not None: - estimator.set_params(max_iter=min(5, estimator.max_iter)) - # LinearSVR, LinearSVC - if name in ["LinearSVR", "LinearSVC"]: - estimator.set_params(max_iter=20) - # NMF - if name == "NMF": - estimator.set_params(max_iter=500) - # DictionaryLearning - if name == "DictionaryLearning": - estimator.set_params(max_iter=20, transform_algorithm="lasso_lars") - # MiniBatchNMF - if estimator.__class__.__name__ == "MiniBatchNMF": - estimator.set_params(max_iter=20, fresh_restarts=True) - # MLP - if name in ["MLPClassifier", "MLPRegressor"]: - estimator.set_params(max_iter=100) - # MiniBatchDictionaryLearning - if name == "MiniBatchDictionaryLearning": - estimator.set_params(max_iter=5) - - if "n_resampling" in params: - # randomized lasso - estimator.set_params(n_resampling=5) - if "n_estimators" in params: - estimator.set_params(n_estimators=min(5, estimator.n_estimators)) - if "max_trials" in params: - # RANSAC - estimator.set_params(max_trials=10) - if "n_init" in params: - # K-Means - estimator.set_params(n_init=2) - if "batch_size" in params and not name.startswith("MLP"): - estimator.set_params(batch_size=10) - - if name == "MeanShift": - # In the case of check_fit2d_1sample, bandwidth is set to None and - # is thus estimated. De facto it is 0.0 as a single sample is provided - # and this makes the test fails. Hence we give it a placeholder value. - estimator.set_params(bandwidth=1.0) - - if name == "TruncatedSVD": - # TruncatedSVD doesn't run with n_components = n_features - # This is ugly :-/ - estimator.n_components = 1 - - if name == "LassoLarsIC": - # Noise variance estimation does not work when `n_samples < n_features`. - # We need to provide the noise variance explicitly. 
- estimator.set_params(noise_variance=1.0) - - if hasattr(estimator, "n_clusters"): - estimator.n_clusters = min(estimator.n_clusters, 2) - - if hasattr(estimator, "n_best"): - estimator.n_best = 1 - - if name == "SelectFdr": - # be tolerant of noisy datasets (not actually speed) - estimator.set_params(alpha=0.5) - - if name == "TheilSenRegressor": - estimator.max_subpopulation = 100 - - if isinstance(estimator, BaseRandomProjection): - # Due to the jl lemma and often very few samples, the number - # of components of the random matrix projection will be probably - # greater than the number of features. - # So we impose a smaller number (avoid "auto" mode) - estimator.set_params(n_components=2) - - if isinstance(estimator, SelectKBest): - # SelectKBest has a default of k=10 - # which is more feature than we have in most case. - estimator.set_params(k=1) - - if name in ("HistGradientBoostingClassifier", "HistGradientBoostingRegressor"): - # The default min_samples_leaf (20) isn't appropriate for small - # datasets (only very shallow trees are built) that the checks use. - estimator.set_params(min_samples_leaf=5) - - if name == "DummyClassifier": - # the default strategy prior would output constant predictions and fail - # for check_classifiers_predictions - estimator.set_params(strategy="stratified") - - # Speed-up by reducing the number of CV or splits for CV estimators - loo_cv = ["RidgeCV", "RidgeClassifierCV"] - if name not in loo_cv and hasattr(estimator, "cv"): - estimator.set_params(cv=3) - if hasattr(estimator, "n_splits"): - estimator.set_params(n_splits=3) - - if name == "OneHotEncoder": - estimator.set_params(handle_unknown="ignore") - - if name in CROSS_DECOMPOSITION: - estimator.set_params(n_components=1) - - # Default "auto" parameter can lead to different ordering of eigenvalues on - # windows: #24105 - if name == "SpectralEmbedding": - estimator.set_params(eigen_tol=1e-5) - - if name == "HDBSCAN": - estimator.set_params(min_samples=1) - - class _NotAnArray: """An object that is convertible to an array. 
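As a usage sketch (illustrative only, not part of the diff): after this commit the instance-generation helpers keep their old behaviour but are imported from the new private module rather than from `sklearn.utils.estimator_checks`.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.utils._test_common.instance_generator import (
    _construct_instance,
    _get_check_estimator_ids,
    _set_checking_parameters,
)

# Build a default instance, shrink its parameters for fast tests, and
# render the pytest id used by parametrize_with_checks.
est = _construct_instance(LogisticRegression)
_set_checking_parameters(est)  # applies TEST_PARAMS, e.g. max_iter=5
print(_get_check_estimator_ids(est))  # -> "LogisticRegression(max_iter=5)"
```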
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 8ac7ac9db2e9a..066277ff24af6 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -30,6 +30,7 @@
 from sklearn.svm import SVC, NuSVC
 from sklearn.utils import _array_api, all_estimators, deprecated
 from sklearn.utils._param_validation import Interval, StrOptions
+from sklearn.utils._test_common.instance_generator import _set_checking_parameters
 from sklearn.utils._testing import (
     MinimalClassifier,
     MinimalRegressor,
@@ -40,7 +41,6 @@
 )
 from sklearn.utils.estimator_checks import (
     _NotAnArray,
-    _set_checking_parameters,
     _yield_all_checks,
     check_array_api_input,
     check_class_weight_balanced_linear_classifier,

From 3474eea41c12ded3e7b9db287b309d11f3bd3abd Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Thu, 22 Aug 2024 08:19:22 +0200
Subject: [PATCH 03/11] add legacy to check_estimator

---
 sklearn/utils/estimator_checks.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 01b1276edcdaa..dbd15bd1c2089 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -594,7 +594,7 @@ def checks_generator():
     )


-def check_estimator(estimator=None, generate_only=False):
+def check_estimator(estimator=None, generate_only=False, legacy=True):
     """Check if estimator adheres to scikit-learn conventions.

     This function will run an extensive test-suite for input validation,
@@ -613,6 +613,11 @@ def check_estimator(estimator=None, generate_only=False):
     :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`, making it
     easier to test multiple estimators.

+    Checks are categorised into the following groups:
+
+    - API checks: a set of checks to ensure API compatibility with scikit-learn
+    - legacy: a set of checks which will gradually be grouped into other categories
+
     Parameters
     ----------
     estimator : estimator object
@@ -630,6 +635,11 @@ def check_estimator(estimator=None, generate_only=False):

         .. versionadded:: 0.22

+    legacy : bool (default=True)
+        Whether to include legacy checks.
+
+        .. 
versionadded:: 1.6 + Returns ------- checks_generator : generator @@ -659,7 +669,7 @@ def check_estimator(estimator=None, generate_only=False): name = type(estimator).__name__ def checks_generator(): - for check in _yield_all_checks(estimator): + for check in _yield_all_checks(estimator, legacy=legacy): check = _maybe_skip(estimator, check) yield estimator, partial(check, name) From 3975f17f28d8ea98fb2f4fcc1911694348ae7ee8 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Thu, 22 Aug 2024 17:44:08 +0200 Subject: [PATCH 04/11] fix tests --- sklearn/utils/tests/test_estimator_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 8ac7ac9db2e9a..7cf7e19f70cfe 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -1210,7 +1210,7 @@ def test_non_deterministic_estimator_skip_tests(): # check estimators with non_deterministic tag set to True # will skip certain tests, refer to issue #22313 for details for est in [MinimalTransformer, MinimalRegressor, MinimalClassifier]: - all_tests = list(_yield_all_checks(est())) + all_tests = list(_yield_all_checks(est(), legacy=True)) assert check_methods_sample_order_invariance in all_tests assert check_methods_subset_invariance in all_tests @@ -1218,7 +1218,7 @@ class Estimator(est): def _more_tags(self): return {"non_deterministic": True} - all_tests = list(_yield_all_checks(Estimator())) + all_tests = list(_yield_all_checks(Estimator(), legacy=True)) assert check_methods_sample_order_invariance not in all_tests assert check_methods_subset_invariance not in all_tests From e27edd3dc930a6fb213ce44b442045a2b4c6932e Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Fri, 23 Aug 2024 09:46:41 +0200 Subject: [PATCH 05/11] remove unnecessary vars --- sklearn/utils/_test_common/instance_generator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py index 3e215111adcda..c8887ad524dd4 100644 --- a/sklearn/utils/_test_common/instance_generator.py +++ b/sklearn/utils/_test_common/instance_generator.py @@ -276,9 +276,6 @@ def _set_checking_parameters(estimator): # set parameters to speed up some estimators and # avoid deprecated behaviour - params = estimator.get_params() - name = estimator.__class__.__name__ - if type(estimator) in TEST_PARAMS: test_params = TEST_PARAMS[type(estimator)] estimator.set_params(**test_params) From ad0560126ff6a8430d671f3cb5df679e5635ab04 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 25 Aug 2024 22:09:40 +0200 Subject: [PATCH 06/11] TST create dedicated dataframe related tests --- sklearn/tests/test_common.py | 9 -- sklearn/utils/estimator_checks.py | 176 +++++++++++++++++------------- 2 files changed, 102 insertions(+), 83 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 3a61503530f23..1d423ddb647df 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -78,7 +78,6 @@ check_global_output_transform_pandas, check_global_set_output_transform_polars, check_inplace_ensure_writeable, - check_n_features_in_after_fitting, check_param_validation, check_set_output_transform, check_set_output_transform_pandas, @@ -367,14 +366,6 @@ def test_valid_tag_types(estimator): assert isinstance(tag, correct_tags) -@pytest.mark.parametrize( - "estimator", _tested_estimators(), ids=_get_check_estimator_ids -) 
-def test_check_n_features_in_after_fitting(estimator): - _set_checking_parameters(estimator) - check_n_features_in_after_fitting(estimator.__class__.__name__, estimator) - - def _estimators_that_predict_in_fit(): for estimator in _tested_estimators(): est_params = set(estimator.get_params()) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index dbd15bd1c2089..6621e68a70886 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -96,6 +96,10 @@ def _yield_api_checks(estimator): yield check_estimators_overwrite_params +def _yield_dataframe_checks(estimator): + yield check_n_features_in_after_fitting + + def _yield_checks(estimator): name = estimator.__class__.__name__ tags = _safe_tags(estimator) @@ -326,7 +330,7 @@ def _yield_array_api_checks(estimator): ) -def _yield_all_checks(estimator, legacy: bool): +def _yield_all_checks(estimator, dataframe: bool, legacy: bool): name = estimator.__class__.__name__ tags = _safe_tags(estimator) if "2darray" not in tags["X_types"]: @@ -347,6 +351,10 @@ def _yield_all_checks(estimator, legacy: bool): for check in _yield_api_checks(estimator): yield check + if dataframe: + for check in _yield_dataframe_checks(estimator): + yield check + if not legacy: return @@ -522,7 +530,7 @@ def _should_be_skipped_or_marked(estimator, check): return False, "placeholder reason that will never be used" -def parametrize_with_checks(estimators, legacy=True): +def parametrize_with_checks(estimators, *, dataframe: bool = True, legacy: bool = True): """Pytest specific decorator for parametrizing estimator checks. Checks are categorised into the following groups: @@ -547,6 +555,13 @@ def parametrize_with_checks(estimators, legacy=True): .. versionadded:: 0.24 + dataframe : bool (default=True) + Whether to included checks related to inspecting feature counts and feature + names. Theese checks might include `polars` or `pandas` to be installed, and are + automatically skipped if otherwise. + + .. versionadded:: 1.6 + legacy : bool (default=True) Whether to include legacy checks. @@ -585,7 +600,9 @@ def parametrize_with_checks(estimators, legacy=True): def checks_generator(): for estimator in estimators: name = type(estimator).__name__ - for check in _yield_all_checks(estimator, legacy=legacy): + for check in _yield_all_checks( + estimator, dataframe=dataframe, legacy=legacy + ): check = partial(check, name) yield _maybe_mark_xfail(estimator, check, pytest) @@ -594,7 +611,9 @@ def checks_generator(): ) -def check_estimator(estimator=None, generate_only=False, legacy=True): +def check_estimator( + estimator=None, generate_only=False, *, dataframe: bool = True, legacy: bool = True +): """Check if estimator adheres to scikit-learn conventions. This function will run an extensive test-suite for input validation, @@ -635,6 +654,13 @@ def check_estimator(estimator=None, generate_only=False, legacy=True): .. versionadded:: 0.22 + dataframe : bool (default=True) + Whether to included checks related to inspecting feature counts and feature + names. Theese checks might include `polars` or `pandas` to be installed, and are + automatically skipped if otherwise. + + .. versionadded:: 1.6 + legacy : bool (default=True) Whether to include legacy checks. 
@@ -669,7 +695,7 @@ def check_estimator(estimator=None, generate_only=False, legacy=True): name = type(estimator).__name__ def checks_generator(): - for check in _yield_all_checks(estimator, legacy=legacy): + for check in _yield_all_checks(estimator, dataframe=dataframe, legacy=legacy): check = _maybe_skip(estimator, check) yield estimator, partial(check, name) @@ -3995,75 +4021,6 @@ def check_requires_y_none(name, estimator_orig): raise ve -@ignore_warnings(category=FutureWarning) -def check_n_features_in_after_fitting(name, estimator_orig): - # Make sure that n_features_in are checked after fitting - tags = _safe_tags(estimator_orig) - - is_supported_X_types = ( - "2darray" in tags["X_types"] or "categorical" in tags["X_types"] - ) - - if not is_supported_X_types or tags["no_validation"]: - return - - rng = np.random.RandomState(0) - - estimator = clone(estimator_orig) - set_random_state(estimator) - if "warm_start" in estimator.get_params(): - estimator.set_params(warm_start=False) - - n_samples = 10 - X = rng.normal(size=(n_samples, 4)) - X = _enforce_estimator_tags_X(estimator, X) - - if is_regressor(estimator): - y = rng.normal(size=n_samples) - else: - y = rng.randint(low=0, high=2, size=n_samples) - y = _enforce_estimator_tags_y(estimator, y) - - estimator.fit(X, y) - assert estimator.n_features_in_ == X.shape[1] - - # check methods will check n_features_in_ - check_methods = [ - "predict", - "transform", - "decision_function", - "predict_proba", - "score", - ] - X_bad = X[:, [1]] - - msg = f"X has 1 features, but \\w+ is expecting {X.shape[1]} features as input" - for method in check_methods: - if not hasattr(estimator, method): - continue - - callable_method = getattr(estimator, method) - if method == "score": - callable_method = partial(callable_method, y=y) - - with raises(ValueError, match=msg): - callable_method(X_bad) - - # partial_fit will check in the second call - if not hasattr(estimator, "partial_fit"): - return - - estimator = clone(estimator_orig) - if is_classifier(estimator): - estimator.partial_fit(X, y, classes=np.unique(y)) - else: - estimator.partial_fit(X, y) - assert estimator.n_features_in_ == X.shape[1] - - with raises(ValueError, match=msg): - estimator.partial_fit(X_bad, y) - - def check_estimator_get_tags_default_keys(name, estimator_orig): # check that if _get_tags is implemented, it contains all keys from # _DEFAULT_KEYS @@ -4793,3 +4750,74 @@ def check_inplace_ensure_writeable(name, estimator_orig): assert not X.flags.writeable assert_allclose(X, X_copy) + + +# Dataframe / Feature Names inspection tests +# ========================================== +@ignore_warnings(category=FutureWarning) +def check_n_features_in_after_fitting(name, estimator_orig): + # Make sure that n_features_in are checked after fitting + tags = _safe_tags(estimator_orig) + + is_supported_X_types = ( + "2darray" in tags["X_types"] or "categorical" in tags["X_types"] + ) + + if not is_supported_X_types or tags["no_validation"]: + return + + rng = np.random.RandomState(0) + + estimator = clone(estimator_orig) + set_random_state(estimator) + if "warm_start" in estimator.get_params(): + estimator.set_params(warm_start=False) + + n_samples = 10 + X = rng.normal(size=(n_samples, 4)) + X = _enforce_estimator_tags_X(estimator, X) + + if is_regressor(estimator): + y = rng.normal(size=n_samples) + else: + y = rng.randint(low=0, high=2, size=n_samples) + y = _enforce_estimator_tags_y(estimator, y) + + estimator.fit(X, y) + assert estimator.n_features_in_ == X.shape[1] + + # check methods 
will check n_features_in_ + check_methods = [ + "predict", + "transform", + "decision_function", + "predict_proba", + "score", + ] + X_bad = X[:, [1]] + + msg = f"X has 1 features, but \\w+ is expecting {X.shape[1]} features as input" + for method in check_methods: + if not hasattr(estimator, method): + continue + + callable_method = getattr(estimator, method) + if method == "score": + callable_method = partial(callable_method, y=y) + + with raises(ValueError, match=msg): + callable_method(X_bad) + + # partial_fit will check in the second call + if not hasattr(estimator, "partial_fit"): + return + + estimator = clone(estimator_orig) + if is_classifier(estimator): + estimator.partial_fit(X, y, classes=np.unique(y)) + else: + estimator.partial_fit(X, y) + assert estimator.n_features_in_ == X.shape[1] + + with raises(ValueError, match=msg): + estimator.partial_fit(X_bad, y) From 3b247ace91565ed632afe25166927aa7b11de82d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Sun, 25 Aug 2024 22:22:28 +0200 Subject: [PATCH 07/11] move check_pandas_column_name_consistency --- sklearn/tests/test_common.py | 46 --- .../utils/_test_common/instance_generator.py | 22 ++ sklearn/utils/estimator_checks.py | 306 ++++++++++-------- sklearn/utils/tests/test_estimator_checks.py | 10 +- 4 files changed, 190 insertions(+), 194 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 81f1218b5520f..4bdf09f6f0bea 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -64,7 +64,6 @@ ) from sklearn.utils.estimator_checks import ( check_class_weight_balanced_linear_classifier, - check_dataframe_column_names_consistency, check_estimator, check_get_feature_names_out_error, check_global_output_transform_pandas, @@ -284,51 +283,6 @@ def test_valid_tag_types(estimator): assert isinstance(tag, correct_tags) -def _estimators_that_predict_in_fit(): - for estimator in _tested_estimators(): - est_params = set(estimator.get_params()) - if "oob_score" in est_params: - yield estimator.set_params(oob_score=True, bootstrap=True) - elif "early_stopping" in est_params: - est = estimator.set_params(early_stopping=True, n_iter_no_change=1) - if est.__class__.__name__ in {"MLPClassifier", "MLPRegressor"}: - # TODO: FIX MLP to not check validation set during MLP - yield pytest.param( - est, marks=pytest.mark.xfail(msg="MLP still validates in fit") - ) - else: - yield est - elif "n_iter_no_change" in est_params: - yield estimator.set_params(n_iter_no_change=1) - - -# NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that -# delegates validation to a base estimator, the check is testing that the base estimator -# is checking for column name consistency. 
-column_name_estimators = list( - chain( - _tested_estimators(), - [make_pipeline(LogisticRegression(C=1))], - list(_generate_search_cv_instances()), - _estimators_that_predict_in_fit(), - ) -) - - -@pytest.mark.parametrize( - "estimator", column_name_estimators, ids=_get_check_estimator_ids -) -def test_pandas_column_name_consistency(estimator): - _set_checking_parameters(estimator) - with ignore_warnings(category=(FutureWarning)): - with warnings.catch_warnings(record=True) as record: - check_dataframe_column_names_consistency( - estimator.__class__.__name__, estimator - ) - for warning in record: - assert "was fitted without feature names" not in str(warning.message) - - # TODO: As more modules support get_feature_names_out they should be removed # from this list to be tested GET_FEATURES_OUT_MODULES_TO_IGNORE = [ diff --git a/sklearn/utils/_test_common/instance_generator.py b/sklearn/utils/_test_common/instance_generator.py index c8887ad524dd4..3754eead9da59 100644 --- a/sklearn/utils/_test_common/instance_generator.py +++ b/sklearn/utils/_test_common/instance_generator.py @@ -272,6 +272,28 @@ TweedieRegressor: dict(max_iter=5), } +SINGLE_TEST_PARAMS = { + "check_pandas_column_name_consistency": { + BaggingClassifier: dict(oob_score=True), + BaggingRegressor: dict(oob_score=True), + ExtraTreesClassifier: dict(bootstrap=True, oob_score=True), + ExtraTreesRegressor: dict(bootstrap=True, oob_score=True), + GradientBoostingClassifier: dict(n_iter_no_change=1), + GradientBoostingRegressor: dict(n_iter_no_change=1), + HistGradientBoostingClassifier: dict(early_stopping=True, n_iter_no_change=1), + HistGradientBoostingRegressor: dict(early_stopping=True, n_iter_no_change=1), + MLPClassifier: dict(early_stopping=True, n_iter_no_change=1), + MLPRegressor: dict(early_stopping=True, n_iter_no_change=1), + PassiveAggressiveClassifier: dict(early_stopping=True, n_iter_no_change=1), + PassiveAggressiveRegressor: dict(early_stopping=True, n_iter_no_change=1), + Perceptron: dict(early_stopping=True, n_iter_no_change=1), + RandomForestClassifier: dict(oob_score=True), + RandomForestRegressor: dict(oob_score=True), + SGDClassifier: dict(early_stopping=True, n_iter_no_change=1), + SGDRegressor: dict(early_stopping=True, n_iter_no_change=1), + } +} + def _set_checking_parameters(estimator): # set parameters to speed up some estimators and diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 94a421855a944..97f86511bbdee 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -60,8 +60,10 @@ ) from ._test_common.instance_generator import ( CROSS_DECOMPOSITION, + SINGLE_TEST_PARAMS, _construct_instance, _get_check_estimator_ids, + _set_checking_parameters, ) from ._testing import ( SkipTest, @@ -91,6 +93,7 @@ def _yield_api_checks(estimator): def _yield_dataframe_checks(estimator): yield check_n_features_in_after_fitting + yield check_pandas_column_name_consistency def _yield_checks(estimator): @@ -3828,149 +3831,6 @@ def check_estimator_get_tags_default_keys(name, estimator_orig): ) -def check_dataframe_column_names_consistency(name, estimator_orig): - try: - import pandas as pd - except ImportError: - raise SkipTest( - "pandas is not installed: not checking column name consistency for pandas" - ) - - tags = _safe_tags(estimator_orig) - is_supported_X_types = ( - "2darray" in tags["X_types"] or "categorical" in tags["X_types"] - ) - - if not is_supported_X_types or tags["no_validation"]: - return - - rng = np.random.RandomState(0) - - 
estimator = clone(estimator_orig) - set_random_state(estimator) - - X_orig = rng.normal(size=(150, 8)) - - X_orig = _enforce_estimator_tags_X(estimator, X_orig) - n_samples, n_features = X_orig.shape - - names = np.array([f"col_{i}" for i in range(n_features)]) - X = pd.DataFrame(X_orig, columns=names, copy=False) - - if is_regressor(estimator): - y = rng.normal(size=n_samples) - else: - y = rng.randint(low=0, high=2, size=n_samples) - y = _enforce_estimator_tags_y(estimator, y) - - # Check that calling `fit` does not raise any warnings about feature names. - with warnings.catch_warnings(): - warnings.filterwarnings( - "error", - message="X does not have valid feature names", - category=UserWarning, - module="sklearn", - ) - estimator.fit(X, y) - - if not hasattr(estimator, "feature_names_in_"): - raise ValueError( - "Estimator does not have a feature_names_in_ " - "attribute after fitting with a dataframe" - ) - assert isinstance(estimator.feature_names_in_, np.ndarray) - assert estimator.feature_names_in_.dtype == object - assert_array_equal(estimator.feature_names_in_, names) - - # Only check sklearn estimators for feature_names_in_ in docstring - module_name = estimator_orig.__module__ - if ( - module_name.startswith("sklearn.") - and not ("test_" in module_name or module_name.endswith("_testing")) - and ("feature_names_in_" not in (estimator_orig.__doc__)) - ): - raise ValueError( - f"Estimator {name} does not document its feature_names_in_ attribute" - ) - - check_methods = [] - for method in ( - "predict", - "transform", - "decision_function", - "predict_proba", - "score", - "score_samples", - "predict_log_proba", - ): - if not hasattr(estimator, method): - continue - - callable_method = getattr(estimator, method) - if method == "score": - callable_method = partial(callable_method, y=y) - check_methods.append((method, callable_method)) - - for _, method in check_methods: - with warnings.catch_warnings(): - warnings.filterwarnings( - "error", - message="X does not have valid feature names", - category=UserWarning, - module="sklearn", - ) - method(X) # works without UserWarning for valid features - - invalid_names = [ - (names[::-1], "Feature names must be in the same order as they were in fit."), - ( - [f"another_prefix_{i}" for i in range(n_features)], - ( - "Feature names unseen at fit time:\n- another_prefix_0\n-" - " another_prefix_1\n" - ), - ), - ( - names[:3], - f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n", - ), - ] - params = { - key: value - for key, value in estimator.get_params().items() - if "early_stopping" in key - } - early_stopping_enabled = any(value is True for value in params.values()) - - for invalid_name, additional_message in invalid_names: - X_bad = pd.DataFrame(X, columns=invalid_name, copy=False) - - expected_msg = re.escape( - "The feature names should match those that were passed during fit.\n" - f"{additional_message}" - ) - for name, method in check_methods: - with raises( - ValueError, match=expected_msg, err_msg=f"{name} did not raise" - ): - method(X_bad) - - # partial_fit checks on second call - # Do not call partial fit if early_stopping is on - if not hasattr(estimator, "partial_fit") or early_stopping_enabled: - continue - - estimator = clone(estimator_orig) - if is_classifier(estimator): - classes = np.unique(y) - estimator.partial_fit(X, y, classes=classes) - else: - estimator.partial_fit(X, y) - - with raises(ValueError, match=expected_msg): - estimator.partial_fit(X_bad, y) - - def 
 def check_transformer_get_feature_names_out(name, transformer_orig):
     tags = transformer_orig._get_tags()
     if "2darray" not in tags["X_types"] or tags["no_validation"]:
         return
@@ -4613,3 +4473,163 @@ def check_n_features_in_after_fitting(name, estimator_orig):
 
         with raises(ValueError, match=msg):
             estimator.partial_fit(X_bad, y)
+
+
+def _check_dataframe_column_names_consistency(name, estimator_orig):
+    try:
+        import pandas as pd
+    except ImportError:
+        raise SkipTest(
+            "pandas is not installed: not checking column name consistency for pandas"
+        )
+
+    tags = _safe_tags(estimator_orig)
+    is_supported_X_types = (
+        "2darray" in tags["X_types"] or "categorical" in tags["X_types"]
+    )
+
+    if not is_supported_X_types or tags["no_validation"]:
+        return
+
+    rng = np.random.RandomState(0)
+
+    estimator = clone(estimator_orig)
+    set_random_state(estimator)
+
+    X_orig = rng.normal(size=(150, 8))
+
+    X_orig = _enforce_estimator_tags_X(estimator, X_orig)
+    n_samples, n_features = X_orig.shape
+
+    names = np.array([f"col_{i}" for i in range(n_features)])
+    X = pd.DataFrame(X_orig, columns=names, copy=False)
+
+    if is_regressor(estimator):
+        y = rng.normal(size=n_samples)
+    else:
+        y = rng.randint(low=0, high=2, size=n_samples)
+    y = _enforce_estimator_tags_y(estimator, y)
+
+    # Check that calling `fit` does not raise any warnings about feature names.
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "error",
+            message="X does not have valid feature names",
+            category=UserWarning,
+            module="sklearn",
+        )
+        estimator.fit(X, y)
+
+    if not hasattr(estimator, "feature_names_in_"):
+        raise ValueError(
+            "Estimator does not have a feature_names_in_ "
+            "attribute after fitting with a dataframe"
+        )
+    assert isinstance(estimator.feature_names_in_, np.ndarray)
+    assert estimator.feature_names_in_.dtype == object
+    assert_array_equal(estimator.feature_names_in_, names)
+
+    # Only check sklearn estimators for feature_names_in_ in docstring
+    module_name = estimator_orig.__module__
+    if (
+        module_name.startswith("sklearn.")
+        and not ("test_" in module_name or module_name.endswith("_testing"))
+        and ("feature_names_in_" not in (estimator_orig.__doc__))
+    ):
+        raise ValueError(
+            f"Estimator {name} does not document its feature_names_in_ attribute"
+        )
+
+    check_methods = []
+    for method in (
+        "predict",
+        "transform",
+        "decision_function",
+        "predict_proba",
+        "score",
+        "score_samples",
+        "predict_log_proba",
+    ):
+        if not hasattr(estimator, method):
+            continue
+
+        callable_method = getattr(estimator, method)
+        if method == "score":
+            callable_method = partial(callable_method, y=y)
+        check_methods.append((method, callable_method))
+
+    for _, method in check_methods:
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "error",
+                message="X does not have valid feature names",
+                category=UserWarning,
+                module="sklearn",
+            )
+            method(X)  # works without UserWarning for valid features
+
+    invalid_names = [
+        (names[::-1], "Feature names must be in the same order as they were in fit."),
+        (
+            [f"another_prefix_{i}" for i in range(n_features)],
+            (
+                "Feature names unseen at fit time:\n- another_prefix_0\n-"
+                " another_prefix_1\n"
+            ),
+        ),
+        (
+            names[:3],
+            f"Feature names seen at fit time, yet now missing:\n- {min(names[3:])}\n",
+        ),
+    ]
+    params = {
+        key: value
+        for key, value in estimator.get_params().items()
+        if "early_stopping" in key
+    }
+    early_stopping_enabled = any(value is True for value in params.values())
+
+    for invalid_name, additional_message in invalid_names:
+        X_bad = pd.DataFrame(X, columns=invalid_name, copy=False)
+
+        expected_msg = re.escape(
+            "The feature names should match those that were passed during fit.\n"
+            f"{additional_message}"
+        )
+        for name, method in check_methods:
+            with raises(
+                ValueError, match=expected_msg, err_msg=f"{name} did not raise"
+            ):
+                method(X_bad)
+
+        # partial_fit checks on second call
+        # Do not call partial fit if early_stopping is on
+        if not hasattr(estimator, "partial_fit") or early_stopping_enabled:
+            continue
+
+        estimator = clone(estimator_orig)
+        if is_classifier(estimator):
+            classes = np.unique(y)
+            estimator.partial_fit(X, y, classes=classes)
+        else:
+            estimator.partial_fit(X, y)
+
+        with raises(ValueError, match=expected_msg):
+            estimator.partial_fit(X_bad, y)
+
+
+def check_pandas_column_name_consistency(name, estimator):
+    _set_checking_parameters(estimator)
+
+    # NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator
+    # that delegates validation to a base estimator, the check is testing that the base
+    # estimator is checking for column name consistency.
+    test_estimator_params = SINGLE_TEST_PARAMS[check_pandas_column_name_consistency]
+    if type(estimator) in test_estimator_params:
+        estimator.set_params(test_estimator_params[type(estimator)])
+
+    with ignore_warnings(category=(FutureWarning)):
+        with warnings.catch_warnings(record=True) as record:
+            _check_dataframe_column_names_consistency(name, estimator)
+        for warning in record:
+            assert "was fitted without feature names" not in str(warning.message)
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 46154e5bdf3d8..f561c852e66eb 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -40,6 +40,7 @@
     raises,
 )
 from sklearn.utils.estimator_checks import (
+    _check_dataframe_column_names_consistency,
     _NotAnArray,
     _yield_all_checks,
     check_array_api_input,
@@ -48,7 +49,6 @@
     check_classifiers_multilabel_output_format_decision_function,
     check_classifiers_multilabel_output_format_predict,
     check_classifiers_multilabel_output_format_predict_proba,
-    check_dataframe_column_names_consistency,
     check_decision_proba_consistency,
     check_estimator,
     check_estimator_get_tags_default_keys,
@@ -864,17 +864,17 @@ def test_check_estimator_get_tags_default_keys():
 def test_check_dataframe_column_names_consistency():
     err_msg = "Estimator does not have a feature_names_in_"
     with raises(ValueError, match=err_msg):
-        check_dataframe_column_names_consistency("estimator_name", BaseBadClassifier())
-    check_dataframe_column_names_consistency("estimator_name", PartialFitChecksName())
+        _check_dataframe_column_names_consistency("estimator_name", BaseBadClassifier())
+    _check_dataframe_column_names_consistency("estimator_name", PartialFitChecksName())
 
     lr = LogisticRegression()
-    check_dataframe_column_names_consistency(lr.__class__.__name__, lr)
+    _check_dataframe_column_names_consistency(lr.__class__.__name__, lr)
     lr.__doc__ = "Docstring that does not document the estimator's attributes"
     err_msg = (
         "Estimator LogisticRegression does not document its feature_names_in_ attribute"
     )
     with raises(ValueError, match=err_msg):
-        check_dataframe_column_names_consistency(lr.__class__.__name__, lr)
+        _check_dataframe_column_names_consistency(lr.__class__.__name__, lr)
 
 
 class _BaseMultiLabelClassifierMock(ClassifierMixin, BaseEstimator):
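The lookup above indexes `SINGLE_TEST_PARAMS` with the check function object although the table is keyed by the string name, and it then passes the override dict to `set_params` positionally; both fail at run time, which the next patch corrects. A minimal reproduction of the two failure modes (a sketch with stand-in objects, not the sklearn internals)::

    from sklearn.linear_model import SGDClassifier

    table = {"check_pandas_column_name_consistency": {SGDClassifier: dict(early_stopping=True)}}

    def check_pandas_column_name_consistency():
        pass

    try:
        table[check_pandas_column_name_consistency]  # function object is not the string key
    except KeyError:
        pass

    try:
        SGDClassifier().set_params({"early_stopping": True})  # set_params expects **kwargs
    except TypeError:
        pass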
From e5dee7ae8fa37a8a4deeba923e390a3a4b1bbed1 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Thu, 5 Sep 2024 12:08:04 +0200
Subject: [PATCH 08/11] key error

---
 sklearn/utils/estimator_checks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index bd81192f0021a..442eb01cb485c 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -4644,9 +4644,9 @@ def check_pandas_column_name_consistency(name, estimator):
     # NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator
     # that delegates validation to a base estimator, the check is testing that the base
     # estimator is checking for column name consistency.
-    test_estimator_params = SINGLE_TEST_PARAMS[check_pandas_column_name_consistency]
+    test_estimator_params = SINGLE_TEST_PARAMS["check_pandas_column_name_consistency"]
     if type(estimator) in test_estimator_params:
-        estimator.set_params(test_estimator_params[type(estimator)])
+        estimator.set_params(**test_estimator_params[type(estimator)])
 
     with ignore_warnings(category=(FutureWarning)):
         with warnings.catch_warnings(record=True) as record:

From 0565c11c6b4ffbb77289a3d44f640d254e0aa08f Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Thu, 5 Sep 2024 12:43:17 +0200
Subject: [PATCH 09/11] fix side effect

---
 sklearn/tree/_classes.py          | 2 +-
 sklearn/utils/estimator_checks.py | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py
index abfb836a6ec27..0ce9b6ed3386f 100644
--- a/sklearn/tree/_classes.py
+++ b/sklearn/tree/_classes.py
@@ -1957,5 +1957,5 @@ def __sklearn_tags__(self):
             "friedman_mse",
             "poisson",
         }
-        tags.input_tags.allow_nan: allow_nan
+        tags.input_tags.allow_nan = allow_nan
         return tags
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 442eb01cb485c..5e243654f389d 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -360,7 +360,7 @@ def _yield_all_checks(estimator, dataframe: bool, legacy: bool):
         yield check
 
     if not legacy:
-        return
+        return  # pragma: no cover
@@ -4550,11 +4550,11 @@ def _check_dataframe_column_names_consistency(name, estimator_orig):
     assert_array_equal(estimator.feature_names_in_, names)
 
     # Only check sklearn estimators for feature_names_in_ in docstring
-    module_name = estimator_orig.__module__
+    module_name = estimator.__module__
     if (
         module_name.startswith("sklearn.")
         and not ("test_" in module_name or module_name.endswith("_testing"))
-        and ("feature_names_in_" not in (estimator_orig.__doc__))
+        and ("feature_names_in_" not in (estimator.__doc__))
     ):
         raise ValueError(
             f"Estimator {name} does not document its feature_names_in_ attribute"
@@ -4638,7 +4638,8 @@ def _check_dataframe_column_names_consistency(name, estimator_orig):
         estimator.partial_fit(X_bad, y)
 
 
-def check_pandas_column_name_consistency(name, estimator):
+def check_pandas_column_name_consistency(name, estimator_orig):
+    estimator = clone(estimator_orig)
     _set_checking_parameters(estimator)
 
     # NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator
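The one-character tree fix above is easy to miss: `tags.input_tags.allow_nan: allow_nan` is a variable annotation, which evaluates both sides but assigns nothing, so the tag silently kept its old value. The clone added to `check_pandas_column_name_consistency` serves the same goal of avoiding hidden state changes, this time on the caller's estimator. A short illustration of the annotation pitfall with a stand-in class::

    class InputTags:
        allow_nan = False

    tags = InputTags()
    tags.allow_nan: True   # annotation statement: nothing is assigned
    assert tags.allow_nan is False
    tags.allow_nan = True  # assignment: the attribute actually changes
    assert tags.allow_nan is True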
From f844822ad31d4364bc867c0812e12a4bb7756e93 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Fri, 6 Sep 2024 09:44:14 +0200
Subject: [PATCH 10/11] changelog

---
 doc/whats_new/v1.6.rst | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst
index 2021e9bb8ccc0..bf8cf77b8b1ea 100644
--- a/doc/whats_new/v1.6.rst
+++ b/doc/whats_new/v1.6.rst
@@ -344,6 +344,12 @@ Changelog
   calling :func:`utils.validation.check_non_negative`.
   :pr:`29540` by :user:`Tamara Atanasoska `.
 
+- |Enhancement| :func:`utils.estimator_checks.parametrize_with_checks` and
+  :func:`utils.estimator_checks.check_estimator` now categorize the checks into
+  groups which can be enabled or disabled using their `dataframe` and `legacy`
+  parameters.
+  :pr:`29699`, :pr:`29713` by `Adrin Jalali`_.
+
 - |API| the `assert_all_finite` parameter of functions :func:`utils.check_array`,
   :func:`utils.check_X_y`, :func:`utils.as_float_array` is renamed into
   `ensure_all_finite`. `force_all_finite` will be removed in 1.8.

From bee4ad79b2fc4391e28d03109c557f36c68c33c3 Mon Sep 17 00:00:00 2001
From: adrinjalali
Date: Sun, 8 Sep 2024 13:07:36 +0200
Subject: [PATCH 11/11] more fixes

---
 sklearn/neighbors/_lof.py                    |  8 ++++----
 sklearn/utils/tests/test_estimator_checks.py | 20 ++++++++++----------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/sklearn/neighbors/_lof.py b/sklearn/neighbors/_lof.py
index c05a4f60773b0..52de9e2fe22b9 100644
--- a/sklearn/neighbors/_lof.py
+++ b/sklearn/neighbors/_lof.py
@@ -7,10 +7,9 @@
 import numpy as np
 
 from ..base import OutlierMixin, _fit_context
-from ..utils import check_array
 from ..utils._param_validation import Interval, StrOptions
 from ..utils.metaestimators import available_if
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, validate_data
 from ._base import KNeighborsMixin, NeighborsBase
 
 __all__ = ["LocalOutlierFactor"]
@@ -471,13 +470,14 @@ def score_samples(self, X):
         The lower, the more abnormal.
         """
         check_is_fitted(self)
-        X = check_array(X, accept_sparse="csr")
+        # not replacing X since we need to pass raw X to kneighbors
+        X_validated = validate_data(self, X, reset=False, accept_sparse="csr")
 
         distances_X, neighbors_indices_X = self.kneighbors(
             X, n_neighbors=self.n_neighbors_
         )
 
-        if X.dtype == np.float32:
+        if X_validated.dtype == np.float32:
             distances_X = distances_X.astype(X.dtype, copy=False)
 
         X_lrd = self._local_reachability_density(
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index bf70a2ec19b94..6d0b875ae118c 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -1204,7 +1204,7 @@ def test_non_deterministic_estimator_skip_tests():
     # check estimators with non_deterministic tag set to True
     # will skip certain tests, refer to issue #22313 for details
     for est in [MinimalTransformer, MinimalRegressor, MinimalClassifier]:
-        all_tests = list(_yield_all_checks(est(), legacy=True))
+        all_tests = list(_yield_all_checks(est(), dataframe=True, legacy=True))
         assert check_methods_sample_order_invariance in all_tests
         assert check_methods_subset_invariance in all_tests
 
@@ -1214,7 +1214,7 @@ def __sklearn_tags__(self):
             tags.non_deterministic = True
             return tags
 
-    all_tests = list(_yield_all_checks(Estimator(), legacy=True))
+    all_tests = list(_yield_all_checks(Estimator(), dataframe=True, legacy=True))
     assert check_methods_sample_order_invariance not in all_tests
     assert check_methods_subset_invariance not in all_tests
 
@@ -1283,14 +1283,14 @@ def test_decision_proba_tie_ranking():
     check_decision_proba_consistency("SGDClassifier", estimator)
 
 
-def test_yield_all_checks_legacy():
-    # Test that _yield_all_checks with legacy=True returns more checks.
+def test_yield_all_checks_api():
+    # Test that _yield_all_checks with only API checks yields fewer checks.
     estimator = MinimalClassifier()
 
-    legacy_checks = list(_yield_all_checks(estimator, legacy=True))
-    non_legacy_checks = list(_yield_all_checks(estimator, legacy=False))
+    all_checks = list(_yield_all_checks(estimator, dataframe=True, legacy=True))
+    api_only_checks = list(_yield_all_checks(estimator, dataframe=False, legacy=False))
 
-    assert len(legacy_checks) > len(non_legacy_checks)
+    assert len(all_checks) > len(api_only_checks)
 
     def get_check_name(check):
         try:
@@ -1299,9 +1299,9 @@ def get_check_name(check):
             return check.func.__name__
 
     # Check that all non-legacy checks are included in legacy checks
-    non_legacy_check_names = {get_check_name(check) for check in non_legacy_checks}
-    legacy_check_names = {get_check_name(check) for check in legacy_checks}
-    assert non_legacy_check_names.issubset(legacy_check_names)
+    api_only_check_names = {get_check_name(check) for check in api_only_checks}
+    all_check_names = {get_check_name(check) for check in all_checks}
+    assert api_only_check_names.issubset(all_check_names)
 
 
 def test_check_estimator_cloneable_error():
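For downstream projects, the practical effect of this series is that an estimator's check run can be restricted to the API group. A usage sketch, assuming the `legacy` parameter keeps the default introduced earlier in the series::

    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.estimator_checks import parametrize_with_checks

    # legacy=False drops the legacy group and keeps the API checks
    @parametrize_with_checks([LogisticRegression()], legacy=False)
    def test_sklearn_api_compat(estimator, check):
        check(estimator)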