Skip to content

Commit 3cb7b58

Browse files
TST refactor instance generation and parameter setting (#29702)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
1 parent 1c3dcb4 commit 3cb7b58

File tree

10 files changed

+471
-311
lines changed

10 files changed

+471
-311
lines changed

doc/sphinxext/allow_nan_estimators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from docutils.parsers.rst import Directive
55

66
from sklearn.utils import all_estimators
7+
from sklearn.utils._test_common.instance_generator import _construct_instance
78
from sklearn.utils._testing import SkipTest
8-
from sklearn.utils.estimator_checks import _construct_instance
99

1010

1111
class AllowNanEstimators(Directive):

sklearn/decomposition/tests/test_pca.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
yield_namespace_device_dtype_combinations,
1818
)
1919
from sklearn.utils._array_api import device as array_device
20+
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
2021
from sklearn.utils._testing import _array_api_for_tests, assert_allclose
2122
from sklearn.utils.estimator_checks import (
22-
_get_check_estimator_ids,
2323
check_array_api_input_and_values,
2424
)
2525
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS

sklearn/linear_model/tests/test_ridge.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
yield_namespace_device_dtype_combinations,
4949
yield_namespaces,
5050
)
51+
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
5152
from sklearn.utils._testing import (
5253
assert_allclose,
5354
assert_almost_equal,
@@ -57,7 +58,6 @@
5758
)
5859
from sklearn.utils.estimator_checks import (
5960
_array_api_for_tests,
60-
_get_check_estimator_ids,
6161
check_array_api_input_and_values,
6262
)
6363
from sklearn.utils.fixes import (

sklearn/preprocessing/tests/test_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from sklearn.utils._array_api import (
4141
yield_namespace_device_dtype_combinations,
4242
)
43+
from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
4344
from sklearn.utils._testing import (
4445
_convert_container,
4546
assert_allclose,
@@ -51,7 +52,6 @@
5152
skip_if_32bit,
5253
)
5354
from sklearn.utils.estimator_checks import (
54-
_get_check_estimator_ids,
5555
check_array_api_input_and_values,
5656
)
5757
from sklearn.utils.fixes import (

sklearn/tests/test_common.py

+13-91
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import warnings
1212
from functools import partial
1313
from inspect import isgenerator, signature
14-
from itertools import chain, product
14+
from itertools import chain
1515

1616
import numpy as np
1717
import pytest
@@ -26,32 +26,24 @@
2626
MeanShift,
2727
SpectralClustering,
2828
)
29-
from sklearn.compose import ColumnTransformer
3029
from sklearn.datasets import make_blobs
31-
from sklearn.decomposition import PCA
3230
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
33-
34-
# make it possible to discover experimental estimators when calling `all_estimators`
3531
from sklearn.experimental import (
3632
enable_halving_search_cv, # noqa
3733
enable_iterative_imputer, # noqa
3834
)
39-
from sklearn.linear_model import LogisticRegression, Ridge
35+
36+
# make it possible to discover experimental estimators when calling `all_estimators`
37+
from sklearn.linear_model import LogisticRegression
4038
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
41-
from sklearn.model_selection import (
42-
GridSearchCV,
43-
HalvingGridSearchCV,
44-
HalvingRandomSearchCV,
45-
RandomizedSearchCV,
46-
)
4739
from sklearn.neighbors import (
4840
KNeighborsClassifier,
4941
KNeighborsRegressor,
5042
LocalOutlierFactor,
5143
RadiusNeighborsClassifier,
5244
RadiusNeighborsRegressor,
5345
)
54-
from sklearn.pipeline import Pipeline, make_pipeline
46+
from sklearn.pipeline import make_pipeline
5547
from sklearn.preprocessing import (
5648
FunctionTransformer,
5749
MinMaxScaler,
@@ -61,15 +53,19 @@
6153
from sklearn.semi_supervised import LabelPropagation, LabelSpreading
6254
from sklearn.utils import all_estimators
6355
from sklearn.utils._tags import _DEFAULT_TAGS, _safe_tags
56+
from sklearn.utils._test_common.instance_generator import (
57+
_generate_column_transformer_instances,
58+
_generate_pipeline,
59+
_generate_search_cv_instances,
60+
_get_check_estimator_ids,
61+
_set_checking_parameters,
62+
_tested_estimators,
63+
)
6464
from sklearn.utils._testing import (
6565
SkipTest,
6666
ignore_warnings,
67-
set_random_state,
6867
)
6968
from sklearn.utils.estimator_checks import (
70-
_construct_instance,
71-
_get_check_estimator_ids,
72-
_set_checking_parameters,
7369
check_dataframe_column_names_consistency,
7470
check_estimator,
7571
check_get_feature_names_out_error,
@@ -137,26 +133,6 @@ def test_get_check_estimator_ids(val, expected):
137133
assert _get_check_estimator_ids(val) == expected
138134

139135

140-
def _tested_estimators(type_filter=None):
141-
for name, Estimator in all_estimators(type_filter=type_filter):
142-
try:
143-
estimator = _construct_instance(Estimator)
144-
except SkipTest:
145-
continue
146-
147-
yield estimator
148-
149-
150-
def _generate_pipeline():
151-
for final_estimator in [Ridge(), LogisticRegression()]:
152-
yield Pipeline(
153-
steps=[
154-
("scaler", StandardScaler()),
155-
("final_estimator", final_estimator),
156-
]
157-
)
158-
159-
160136
@parametrize_with_checks(list(chain(_tested_estimators(), _generate_pipeline())))
161137
def test_estimators(estimator, check, request):
162138
# Common tests for estimator instances
@@ -259,60 +235,6 @@ def test_class_support_removed():
259235
parametrize_with_checks([LogisticRegression])
260236

261237

262-
def _generate_column_transformer_instances():
263-
yield ColumnTransformer(
264-
transformers=[
265-
("trans1", StandardScaler(), [0, 1]),
266-
]
267-
)
268-
269-
270-
def _generate_search_cv_instances():
271-
for SearchCV, (Estimator, param_grid) in product(
272-
[
273-
GridSearchCV,
274-
HalvingGridSearchCV,
275-
RandomizedSearchCV,
276-
HalvingGridSearchCV,
277-
],
278-
[
279-
(Ridge, {"alpha": [0.1, 1.0]}),
280-
(LogisticRegression, {"C": [0.1, 1.0]}),
281-
],
282-
):
283-
init_params = signature(SearchCV).parameters
284-
extra_params = (
285-
{"min_resources": "smallest"} if "min_resources" in init_params else {}
286-
)
287-
search_cv = SearchCV(
288-
Estimator(), param_grid, cv=2, error_score="raise", **extra_params
289-
)
290-
set_random_state(search_cv)
291-
yield search_cv
292-
293-
for SearchCV, (Estimator, param_grid) in product(
294-
[
295-
GridSearchCV,
296-
HalvingGridSearchCV,
297-
RandomizedSearchCV,
298-
HalvingRandomSearchCV,
299-
],
300-
[
301-
(Ridge, {"ridge__alpha": [0.1, 1.0]}),
302-
(LogisticRegression, {"logisticregression__C": [0.1, 1.0]}),
303-
],
304-
):
305-
init_params = signature(SearchCV).parameters
306-
extra_params = (
307-
{"min_resources": "smallest"} if "min_resources" in init_params else {}
308-
)
309-
search_cv = SearchCV(
310-
make_pipeline(PCA(), Estimator()), param_grid, cv=2, **extra_params
311-
).set_params(error_score="raise")
312-
set_random_state(search_cv)
313-
yield search_cv
314-
315-
316238
@parametrize_with_checks(list(_generate_search_cv_instances()))
317239
def test_search_cv(estimator, check, request):
318240
# Common tests for SearchCV instances

sklearn/tests/test_docstring_parameters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
from sklearn.linear_model import LogisticRegression
2323
from sklearn.preprocessing import FunctionTransformer
2424
from sklearn.utils import all_estimators
25+
from sklearn.utils._test_common.instance_generator import _construct_instance
2526
from sklearn.utils._testing import (
2627
_get_func_name,
2728
check_docstring_parameters,
2829
ignore_warnings,
2930
)
3031
from sklearn.utils.deprecation import _is_deprecated
3132
from sklearn.utils.estimator_checks import (
32-
_construct_instance,
3333
_enforce_estimator_tags_X,
3434
_enforce_estimator_tags_y,
3535
)
+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Authors: The scikit-learn developers
2+
# SPDX-License-Identifier: BSD-3-Clause

0 commit comments

Comments
 (0)