|
11 | 11 | import warnings
|
12 | 12 | from functools import partial
|
13 | 13 | from inspect import isgenerator, signature
|
14 |
| -from itertools import chain, product |
| 14 | +from itertools import chain |
15 | 15 |
|
16 | 16 | import numpy as np
|
17 | 17 | import pytest
|
|
26 | 26 | MeanShift,
|
27 | 27 | SpectralClustering,
|
28 | 28 | )
|
29 |
| -from sklearn.compose import ColumnTransformer |
30 | 29 | from sklearn.datasets import make_blobs
|
31 |
| -from sklearn.decomposition import PCA |
32 | 30 | from sklearn.exceptions import ConvergenceWarning, FitFailedWarning
|
33 |
| - |
34 |
| -# make it possible to discover experimental estimators when calling `all_estimators` |
35 | 31 | from sklearn.experimental import (
|
36 | 32 | enable_halving_search_cv, # noqa
|
37 | 33 | enable_iterative_imputer, # noqa
|
38 | 34 | )
|
39 |
| -from sklearn.linear_model import LogisticRegression, Ridge |
| 35 | + |
| 36 | +# make it possible to discover experimental estimators when calling `all_estimators` |
| 37 | +from sklearn.linear_model import LogisticRegression |
40 | 38 | from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding
|
41 |
| -from sklearn.model_selection import ( |
42 |
| - GridSearchCV, |
43 |
| - HalvingGridSearchCV, |
44 |
| - HalvingRandomSearchCV, |
45 |
| - RandomizedSearchCV, |
46 |
| -) |
47 | 39 | from sklearn.neighbors import (
|
48 | 40 | KNeighborsClassifier,
|
49 | 41 | KNeighborsRegressor,
|
50 | 42 | LocalOutlierFactor,
|
51 | 43 | RadiusNeighborsClassifier,
|
52 | 44 | RadiusNeighborsRegressor,
|
53 | 45 | )
|
54 |
| -from sklearn.pipeline import Pipeline, make_pipeline |
| 46 | +from sklearn.pipeline import make_pipeline |
55 | 47 | from sklearn.preprocessing import (
|
56 | 48 | FunctionTransformer,
|
57 | 49 | MinMaxScaler,
|
|
61 | 53 | from sklearn.semi_supervised import LabelPropagation, LabelSpreading
|
62 | 54 | from sklearn.utils import all_estimators
|
63 | 55 | from sklearn.utils._tags import _DEFAULT_TAGS, _safe_tags
|
| 56 | +from sklearn.utils._test_common.instance_generator import ( |
| 57 | + _generate_column_transformer_instances, |
| 58 | + _generate_pipeline, |
| 59 | + _generate_search_cv_instances, |
| 60 | + _get_check_estimator_ids, |
| 61 | + _set_checking_parameters, |
| 62 | + _tested_estimators, |
| 63 | +) |
64 | 64 | from sklearn.utils._testing import (
|
65 | 65 | SkipTest,
|
66 | 66 | ignore_warnings,
|
67 |
| - set_random_state, |
68 | 67 | )
|
69 | 68 | from sklearn.utils.estimator_checks import (
|
70 |
| - _construct_instance, |
71 |
| - _get_check_estimator_ids, |
72 |
| - _set_checking_parameters, |
73 | 69 | check_dataframe_column_names_consistency,
|
74 | 70 | check_estimator,
|
75 | 71 | check_get_feature_names_out_error,
|
@@ -137,26 +133,6 @@ def test_get_check_estimator_ids(val, expected):
|
137 | 133 | assert _get_check_estimator_ids(val) == expected
|
138 | 134 |
|
139 | 135 |
|
140 |
| -def _tested_estimators(type_filter=None): |
141 |
| - for name, Estimator in all_estimators(type_filter=type_filter): |
142 |
| - try: |
143 |
| - estimator = _construct_instance(Estimator) |
144 |
| - except SkipTest: |
145 |
| - continue |
146 |
| - |
147 |
| - yield estimator |
148 |
| - |
149 |
| - |
150 |
| -def _generate_pipeline(): |
151 |
| - for final_estimator in [Ridge(), LogisticRegression()]: |
152 |
| - yield Pipeline( |
153 |
| - steps=[ |
154 |
| - ("scaler", StandardScaler()), |
155 |
| - ("final_estimator", final_estimator), |
156 |
| - ] |
157 |
| - ) |
158 |
| - |
159 |
| - |
160 | 136 | @parametrize_with_checks(list(chain(_tested_estimators(), _generate_pipeline())))
|
161 | 137 | def test_estimators(estimator, check, request):
|
162 | 138 | # Common tests for estimator instances
|
@@ -259,60 +235,6 @@ def test_class_support_removed():
|
259 | 235 | parametrize_with_checks([LogisticRegression])
|
260 | 236 |
|
261 | 237 |
|
262 |
| -def _generate_column_transformer_instances(): |
263 |
| - yield ColumnTransformer( |
264 |
| - transformers=[ |
265 |
| - ("trans1", StandardScaler(), [0, 1]), |
266 |
| - ] |
267 |
| - ) |
268 |
| - |
269 |
| - |
270 |
| -def _generate_search_cv_instances(): |
271 |
| - for SearchCV, (Estimator, param_grid) in product( |
272 |
| - [ |
273 |
| - GridSearchCV, |
274 |
| - HalvingGridSearchCV, |
275 |
| - RandomizedSearchCV, |
276 |
| - HalvingGridSearchCV, |
277 |
| - ], |
278 |
| - [ |
279 |
| - (Ridge, {"alpha": [0.1, 1.0]}), |
280 |
| - (LogisticRegression, {"C": [0.1, 1.0]}), |
281 |
| - ], |
282 |
| - ): |
283 |
| - init_params = signature(SearchCV).parameters |
284 |
| - extra_params = ( |
285 |
| - {"min_resources": "smallest"} if "min_resources" in init_params else {} |
286 |
| - ) |
287 |
| - search_cv = SearchCV( |
288 |
| - Estimator(), param_grid, cv=2, error_score="raise", **extra_params |
289 |
| - ) |
290 |
| - set_random_state(search_cv) |
291 |
| - yield search_cv |
292 |
| - |
293 |
| - for SearchCV, (Estimator, param_grid) in product( |
294 |
| - [ |
295 |
| - GridSearchCV, |
296 |
| - HalvingGridSearchCV, |
297 |
| - RandomizedSearchCV, |
298 |
| - HalvingRandomSearchCV, |
299 |
| - ], |
300 |
| - [ |
301 |
| - (Ridge, {"ridge__alpha": [0.1, 1.0]}), |
302 |
| - (LogisticRegression, {"logisticregression__C": [0.1, 1.0]}), |
303 |
| - ], |
304 |
| - ): |
305 |
| - init_params = signature(SearchCV).parameters |
306 |
| - extra_params = ( |
307 |
| - {"min_resources": "smallest"} if "min_resources" in init_params else {} |
308 |
| - ) |
309 |
| - search_cv = SearchCV( |
310 |
| - make_pipeline(PCA(), Estimator()), param_grid, cv=2, **extra_params |
311 |
| - ).set_params(error_score="raise") |
312 |
| - set_random_state(search_cv) |
313 |
| - yield search_cv |
314 |
| - |
315 |
| - |
316 | 238 | @parametrize_with_checks(list(_generate_search_cv_instances()))
|
317 | 239 | def test_search_cv(estimator, check, request):
|
318 | 240 | # Common tests for SearchCV instances
|
|
0 commit comments