-
-
Notifications
You must be signed in to change notification settings - Fork 26k
TST remove _required_parameters and improve instance generation #29707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d28c3cb
13a8e27
3474eea
3975f17
d460786
e27edd3
3cabeb7
7eec068
aa4d808
c3f1249
7c4a3b2
02f99a1
1f220c4
f434406
650bb8e
148e5d5
32a6ec6
1ac2a8e
0c7366c
27d315d
912b14a
fd03329
4a99a3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -378,6 +378,9 @@ def __sklearn_tags__(self): | |
"Fail during parameter check since min/max resources requires" | ||
" more samples" | ||
), | ||
"check_estimators_nan_inf": "FIXME", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These estimators are also tested with others now, and these tests fail. Need fixes in another PR. |
||
"check_classifiers_one_label_sample_weights": "FIXME", | ||
"check_fit2d_1feature": "FIXME", | ||
} | ||
) | ||
return tags | ||
|
@@ -668,8 +671,6 @@ class HalvingGridSearchCV(BaseSuccessiveHalving): | |
{'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9} | ||
""" | ||
|
||
_required_parameters = ["estimator", "param_grid"] | ||
|
||
_parameter_constraints: dict = { | ||
**BaseSuccessiveHalving._parameter_constraints, | ||
"param_grid": [dict, list], | ||
|
@@ -1018,8 +1019,6 @@ class HalvingRandomSearchCV(BaseSuccessiveHalving): | |
{'max_depth': None, 'min_samples_split': 10, 'n_estimators': 9} | ||
""" | ||
|
||
_required_parameters = ["estimator", "param_distributions"] | ||
|
||
_parameter_constraints: dict = { | ||
**BaseSuccessiveHalving._parameter_constraints, | ||
"param_distributions": [dict, list], | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,14 +26,15 @@ | |
MeanShift, | ||
SpectralClustering, | ||
) | ||
from sklearn.compose import ColumnTransformer | ||
from sklearn.datasets import make_blobs | ||
from sklearn.exceptions import ConvergenceWarning, FitFailedWarning | ||
|
||
# make it possible to discover experimental estimators when calling `all_estimators` | ||
from sklearn.experimental import ( | ||
enable_halving_search_cv, # noqa | ||
enable_iterative_imputer, # noqa | ||
) | ||
|
||
# make it possible to discover experimental estimators when calling `all_estimators` | ||
from sklearn.linear_model import LogisticRegression | ||
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding | ||
from sklearn.neighbors import ( | ||
|
@@ -43,7 +44,7 @@ | |
RadiusNeighborsClassifier, | ||
RadiusNeighborsRegressor, | ||
) | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.pipeline import FeatureUnion, make_pipeline | ||
from sklearn.preprocessing import ( | ||
FunctionTransformer, | ||
MinMaxScaler, | ||
|
@@ -54,11 +55,9 @@ | |
from sklearn.utils import all_estimators | ||
from sklearn.utils._tags import get_tags | ||
from sklearn.utils._test_common.instance_generator import ( | ||
_generate_column_transformer_instances, | ||
_generate_pipeline, | ||
_generate_search_cv_instances, | ||
_get_check_estimator_ids, | ||
_set_checking_parameters, | ||
_tested_estimators, | ||
) | ||
from sklearn.utils._testing import ( | ||
|
@@ -139,7 +138,6 @@ def test_estimators(estimator, check, request): | |
with ignore_warnings( | ||
category=(FutureWarning, ConvergenceWarning, UserWarning, LinAlgWarning) | ||
): | ||
_set_checking_parameters(estimator) | ||
check(estimator) | ||
|
||
|
||
|
@@ -285,7 +283,6 @@ def check_field_types(tags, defaults): | |
"estimator", _tested_estimators(), ids=_get_check_estimator_ids | ||
) | ||
def test_check_n_features_in_after_fitting(estimator): | ||
_set_checking_parameters(estimator) | ||
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator) | ||
|
||
|
||
|
@@ -324,7 +321,8 @@ def _estimators_that_predict_in_fit(): | |
"estimator", column_name_estimators, ids=_get_check_estimator_ids | ||
) | ||
def test_pandas_column_name_consistency(estimator): | ||
_set_checking_parameters(estimator) | ||
if isinstance(estimator, ColumnTransformer): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this and |
||
pytest.skip("ColumnTransformer is not tested here") | ||
with ignore_warnings(category=(FutureWarning)): | ||
with warnings.catch_warnings(record=True) as record: | ||
check_dataframe_column_names_consistency( | ||
|
@@ -360,7 +358,6 @@ def _include_in_get_feature_names_out_check(transformer): | |
"transformer", GET_FEATURES_OUT_ESTIMATORS, ids=_get_check_estimator_ids | ||
) | ||
def test_transformers_get_feature_names_out(transformer): | ||
_set_checking_parameters(transformer) | ||
|
||
with ignore_warnings(category=(FutureWarning)): | ||
check_transformer_get_feature_names_out( | ||
|
@@ -381,7 +378,6 @@ def test_transformers_get_feature_names_out(transformer): | |
) | ||
def test_estimators_get_feature_names_out_error(estimator): | ||
estimator_name = estimator.__class__.__name__ | ||
_set_checking_parameters(estimator) | ||
check_get_feature_names_out_error(estimator_name, estimator) | ||
|
||
|
||
|
@@ -409,14 +405,14 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): | |
chain( | ||
_tested_estimators(), | ||
_generate_pipeline(), | ||
_generate_column_transformer_instances(), | ||
_generate_search_cv_instances(), | ||
), | ||
ids=_get_check_estimator_ids, | ||
) | ||
def test_check_param_validation(estimator): | ||
if isinstance(estimator, FeatureUnion): | ||
pytest.skip("FeatureUnion is not tested here") | ||
name = estimator.__class__.__name__ | ||
_set_checking_parameters(estimator) | ||
check_param_validation(name, estimator) | ||
|
||
|
||
|
@@ -481,7 +477,6 @@ def test_set_output_transform(estimator): | |
f"Skipping check_set_output_transform for {name}: Does not support" | ||
" set_output API" | ||
) | ||
_set_checking_parameters(estimator) | ||
with ignore_warnings(category=(FutureWarning)): | ||
check_set_output_transform(estimator.__class__.__name__, estimator) | ||
|
||
|
@@ -505,7 +500,6 @@ def test_set_output_transform_configured(estimator, check_func): | |
f"Skipping {check_func.__name__} for {name}: Does not support" | ||
" set_output API yet" | ||
) | ||
_set_checking_parameters(estimator) | ||
with ignore_warnings(category=(FutureWarning)): | ||
check_func(estimator.__class__.__name__, estimator) | ||
|
||
|
@@ -523,8 +517,6 @@ def test_check_inplace_ensure_writeable(estimator): | |
else: | ||
raise SkipTest(f"{name} doesn't require writeable input.") | ||
|
||
_set_checking_parameters(estimator) | ||
|
||
# The following estimators can work inplace only with certain settings | ||
if name == "HDBSCAN": | ||
estimator.set_params(metric="precomputed", algorithm="brute") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is now not required, and there's no need for a replacement since we now pass instances to estimator checks and not classes.