Run common checks on estimators with non-default parameters #17441
Open
rth wants to merge 4 commits into scikit-learn:main from rth:estimator-checks-non-default-params
+311 −0
@@ -0,0 +1,308 @@
import pprint
import itertools
from inspect import signature
from typing import Optional, List, Dict, Any

from sklearn.base import is_regressor
from sklearn.datasets import load_iris
from sklearn.tree._classes import BaseDecisionTree
from sklearn.ensemble._forest import BaseForest
from sklearn.ensemble._gb import BaseGradientBoosting
from sklearn.utils import all_estimators
from sklearn.utils.estimator_checks import (
    _enforce_estimator_tags_y,
    parametrize_with_checks,
)
from sklearn.utils._testing import ignore_warnings


categorical_params = [
    "solver",
    "algorithm",
    "loss",
    "strategy",
    "selection",
    "criterion",
    "multi_class",
    "kernel",
    "affinity",
    "linkage",
    "metric",
    "init",
    "eigen_solver",
    "initial_strategy",
    "imputation_order",
    "encode",
    "learning_method",
    "method",
    "fit_algorithm",
    "norm",
    "svd_solver",
    "order",
    "mode",
    "output_distribution",
]
bool_params = [
    "fit_prior",
    "fit_intercept",
    "positive",
    "normalize",
    "dual",
    "average",
    "shuffle",
    "whiten",
    "path_method",
    "include_bias",
    "interaction_only",
    "standardize",
    "with_centering",
    "with_scaling",
    "with_mean",
    "with_std",
]


class FakeParam(str):
    """Fake string parameter

    This object always returns False when compared to other values,
    but remembers the values it was compared with. It is used to
    determine the valid values of a categorical parameter.

    Examples
    --------
    >>> fp = FakeParam()
    >>> isinstance(fp, str)
    True
    >>> fp
    fake-param
    >>> fp == 'a'
    False
    >>> fp.values
    {'a'}
    """

    def __init__(self):
        self.values = set()

    def __repr__(self):
        return "fake-param"

    def __hash__(self):
        return hash(str(self))

    def __eq__(self, other):
        # record the candidate value, then report a mismatch so that the
        # estimator keeps comparing against its remaining valid values
        self.values.add(other)
        return False
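

# Note (added for illustration): during e.g. ``Estimator(solver=fp).fit(X, y)``,
# the estimator's internal validation, typically of the form
# ``if self.solver == "liblinear": ...``, invokes ``FakeParam.__eq__``, so each
# candidate value is recorded in ``fp.values`` even though the fit itself fails.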
def detect_valid_categorical_params(
    Estimator, param: str = "solver"
) -> Optional[List]:
    """Detect the valid values of a categorical parameter of an estimator

    Returns
    -------
    list or None : the sorted list of valid values, or None if they
        could not be determined (or the estimator has no parameter
        with that name).

    Example
    -------
    >>> from sklearn.linear_model import LogisticRegression
    >>> detect_valid_categorical_params(LogisticRegression, param="solver")
    ['lbfgs', 'liblinear', 'newton-cg', 'sag', 'saga']
    """
    name = Estimator.__name__
    est_signature = signature(Estimator)
    if param not in est_signature.parameters:
        return None

    # Special cases
    if param == "criterion" and issubclass(
        Estimator, (BaseDecisionTree, BaseForest, BaseGradientBoosting)
    ):
        # hardcode this case, as the FakeParam approach does not work with
        # `param in valid_list` checks.
        from sklearn.tree._classes import CRITERIA_CLF, CRITERIA_REG

        if is_regressor(Estimator):
            return list(CRITERIA_REG)
        else:
            return list(CRITERIA_CLF)
    elif param == "loss":
        # hardcode a few other cases that can't be auto-detected
        from sklearn.linear_model import (
            SGDClassifier,
            SGDRegressor,
        )

        if name == "PassiveAggressiveClassifier":
            return ["hinge", "squared_hinge"]
        elif issubclass(Estimator, SGDClassifier):
            return [
                "hinge",
                "log",
                "modified_huber",
                "squared_hinge",
                "perceptron",
            ]
        elif name == "PassiveAggressiveRegressor":
            return ["epsilon_insensitive", "squared_epsilon_insensitive"]
        elif issubclass(Estimator, SGDRegressor):
            return [
                "squared_loss",
                "huber",
                "epsilon_insensitive",
                "squared_epsilon_insensitive",
            ]
    elif param == "kernel":
        if name in [
            "GaussianProcessClassifier",
            "GaussianProcessRegressor",
            "Nystroem",
        ]:
            # these kernels are not strings but instances, skip
            return None
        elif name == "KernelPCA":
            return ["linear", "poly", "rbf", "sigmoid", "cosine"]
    elif param == "affinity":
        if name in ["AgglomerativeClustering", "FeatureAgglomeration"]:
            return ["euclidean", "l1", "l2", "manhattan", "cosine"]
    elif param == "linkage":
        if name in ["AgglomerativeClustering", "FeatureAgglomeration"]:
            return ["ward", "complete", "average", "single"]
    elif param == "norm":
        if name in ["HashingVectorizer", "TfidfTransformer"]:
            # vectorizers are not supported in common tests for now
            return None
        elif name == "Normalizer":
            return ["l1", "l2", "max"]
        elif name == "ComplementNB":
            return [True, False]
    elif param == "order":
        if name == "PolynomialFeatures":
            return ["F", "C"]
    elif param == "mode":
        if name == "GenericUnivariateSelect":
            return ["percentile", "k_best", "fpr", "fdr", "fwe"]
        elif name in ["KNeighborsTransformer", "RadiusNeighborsTransformer"]:
            return ["distance", "connectivity"]

    fp = FakeParam()

    X, y = load_iris(return_X_y=True)

    # Auto-detect valid string parameters with FakeParam
    try:
        args = {param: fp}
        est = Estimator(**args)
        y = _enforce_estimator_tags_y(est, y)
        with ignore_warnings():
            est.fit(X, y)
    except Exception:
        if not fp.values:
            raise
    if not fp.values:
        raise ValueError(
            f"{Estimator.__name__}: {param}={fp.values} "
            f"should contain at least one element"
        )

    return sorted(fp.values)


def detect_all_params(Estimator) -> Dict[str, List]:
    """Detect all valid categorical and boolean parameters of an estimator

    Example
    -------
    >>> from sklearn.linear_model import LogisticRegression
    >>> detect_all_params(LogisticRegression)
    {'solver': ['lbfgs', 'liblinear', 'newton-cg', 'sag', 'saga'],
     'multi_class': ['auto', 'multinomial', 'ovr'],
     'fit_intercept': [False, True],
     'dual': [False, True]}
    """
    res: Dict[str, Any] = {}
    name = Estimator.__name__
    if name in ["ClassifierChain", "RegressorChain"]:
        # skip meta-estimators for now
        return res

    est_signature = signature(Estimator)
    for param_name in est_signature.parameters:
        if param_name in categorical_params:
            values = detect_valid_categorical_params(
                Estimator, param=param_name
            )
            if values is not None:
                res[param_name] = values
        elif param_name in bool_params:
            res[param_name] = [False, True]
    return res


def _merge_dict_product(**kwargs) -> List[Dict]:
    """Merge the cartesian product of dictionaries with lists

    Example
    -------
    >>> _merge_dict_product(a=[1, 2], b=[True, False], c=['O'])
    [{'a': 1, 'b': True, 'c': 'O'},
     {'a': 1, 'b': False, 'c': 'O'},
     {'a': 2, 'b': True, 'c': 'O'},
     {'a': 2, 'b': False, 'c': 'O'}]
    """
    tmp = []
    for key, val in kwargs.items():
        tmp.append([{key: el} for el in val])

    res = []
    for val in itertools.product(*tmp):
        row: Dict[str, Any] = {}
        for el in val:
            row.update(el)
        res.append(row)

    return res
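

# Note: for dict-of-list inputs, _merge_dict_product builds the same grid of
# parameter combinations as sklearn.model_selection.ParameterGrid
# (iteration order aside).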


def _make_all_estimator_instances(verbose=False):

    X_raw, y_raw = load_iris(return_X_y=True)

    for name, Estimator in all_estimators(
        type_filter=["transformer", "cluster", "classifier", "regressor"]
    ):
        valid_params = detect_all_params(Estimator)
        if verbose:
            print(f"{name}")
        if valid_params:
            for params in _merge_dict_product(**valid_params):
                # check that we can train on Iris, otherwise the
                # parameters are likely incompatible
                try:
                    est = Estimator(**params)
                    y = _enforce_estimator_tags_y(est, y_raw)
                    with ignore_warnings():
                        est.fit(X_raw.copy(), y)
                    # parameters should be OK
                    yield Estimator(**params)
                except Exception:
                    # likely incompatible parameters, skipping
                    pass

        if verbose:
            pprint.pp(valid_params, sort_dicts=True)


@parametrize_with_checks(list(_make_all_estimator_instances()))
def test_common_non_default(estimator, check):
    check(estimator)
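
# Note: parametrize_with_checks above generates one pytest case per
# (estimator instance, check) pair, so each instance that survives the Iris
# smoke-fit is run through all applicable common checks.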


if __name__ == "__main__":
    # print the list of tested estimators
    list(_make_all_estimator_instances(verbose=True))
Would it make sense to do this by introducing type hints for categorical params? Or by parsing the docstring?
For type annotations we would need PEP 586 literal types, and a backport of them for Python <3.8. That's probably indeed the best solution long term.
I'm not very keen on parsing docstrings, as that can break easily and is less reliable than the current approach. For instance, for estimators taking metric names as input, not all valid names are always listed; sometimes there is just a link to the user guide.
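For concreteness, a minimal sketch of what Literal-based detection could look like (the annotated signature below is hypothetical, not scikit-learn's actual API):

from typing import Literal, get_args, get_type_hints

# Hypothetical constructor signature annotated with PEP 586 literal types;
# on Python <3.8 the same names are available from typing_extensions.
def make_estimator(solver: Literal["lbfgs", "liblinear", "saga"] = "lbfgs"):
    ...

# The valid values can then be read off the annotation directly,
# with no need to probe the estimator with a FakeParam.
valid_solvers = get_args(get_type_hints(make_estimator)["solver"])
print(valid_solvers)  # ('lbfgs', 'liblinear', 'saga')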
FWIW the typing-extensions package is an official source of backports too. And Hypothesis' introspection can inspect typing or typing_extensions.Literal[...] and automatically sample from the parameters 😁
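A rough sketch of that, reusing the hypothetical make_estimator signature from the previous sketch (assumes a Hypothesis version that resolves Literal types):

from typing import Literal, get_type_hints

from hypothesis import given, strategies as st

def make_estimator(solver: Literal["lbfgs", "liblinear", "saga"] = "lbfgs"):
    ...

# st.from_type resolves a Literal[...] annotation into a strategy
# that samples from its declared values.
solvers = st.from_type(get_type_hints(make_estimator)["solver"])

@given(solver=solvers)
def test_solver_is_valid(solver):
    assert solver in ("lbfgs", "liblinear", "saga")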
We are working toward having typing in #17799.
Moving forward, we are going to run into invalid parameter combinations when sampling from the parameter space. Libraries such as ConfigSpace use a combination of conditional clauses and forbidden clauses to define the relationships between parameters.
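For illustration, a small ConfigSpace sketch (the solver/penalty space below is hypothetical, modelled loosely on LogisticRegression; API as of ConfigSpace at the time of writing):

from ConfigSpace import ConfigurationSpace
from ConfigSpace.hyperparameters import CategoricalHyperparameter
from ConfigSpace.forbidden import ForbiddenAndConjunction, ForbiddenEqualsClause

cs = ConfigurationSpace(seed=0)
solver = CategoricalHyperparameter("solver", ["lbfgs", "liblinear", "saga"])
penalty = CategoricalHyperparameter("penalty", ["l1", "l2"])
cs.add_hyperparameters([solver, penalty])
# Forbidden clause: rule out the invalid lbfgs + l1 combination.
cs.add_forbidden_clause(
    ForbiddenAndConjunction(
        ForbiddenEqualsClause(solver, "lbfgs"),
        ForbiddenEqualsClause(penalty, "l1"),
    )
)
# Sampling now only yields combinations that satisfy the clauses.
configurations = cs.sample_configuration(5)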
Thanks for the references! I think using hypothesis to sample this parameter space could be interesting long term, as it would also generalize to float parameters, and would limit the runtime once exhaustively enumerating all parameter combinations becomes intractable.
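e.g., a rough sketch for a float parameter (the bounds are picked arbitrarily here):

from hypothesis import given, settings, strategies as st
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

# max_examples caps the runtime instead of exhausting a full grid.
@settings(max_examples=20, deadline=None)
@given(C=st.floats(min_value=1e-3, max_value=1e3))
def test_logistic_regression_C(C):
    LogisticRegression(C=C).fit(X, y)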