Skip to content

MNT Param validation: Allow to skip validation of a parameter #23602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ def test_function_param_validation(func_module):
required_params = [
p.name for p in func_sig.parameters.values() if p.default is p.empty
]
required_params = {
p: generate_valid_param(make_constraint(parameter_constraints[p][0]))
for p in required_params
}
valid_required_params = {}
for param_name in required_params:
if parameter_constraints[param_name] == "no_validation":
valid_required_params[param_name] = 1
else:
valid_required_params[param_name] = generate_valid_param(
make_constraint(parameter_constraints[param_name][0])
)

# check that there is a constraint for each parameter
if func_params:
Expand All @@ -51,18 +55,23 @@ def test_function_param_validation(func_module):
param_with_bad_type = type("BadType", (), {})()

for param_name in func_params:
constraints = parameter_constraints[param_name]

if constraints == "no_validation":
# This parameter is not validated
continue

match = (
rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
)

# First, check that the error is raised if param doesn't match any valid type.
with pytest.raises(ValueError, match=match):
func(**{**required_params, param_name: param_with_bad_type})
func(**{**valid_required_params, param_name: param_with_bad_type})

# Then, for constraints that are more than a type constraint, check that the
# error is raised if param does match a valid type but does not match any valid
# value for this type.
constraints = parameter_constraints[param_name]
constraints = [make_constraint(constraint) for constraint in constraints]

for constraint in constraints:
Expand All @@ -72,4 +81,4 @@ def test_function_param_validation(func_module):
continue

with pytest.raises(ValueError, match=match):
func(**{**required_params, param_name: bad_value})
func(**{**valid_required_params, param_name: bad_value})
13 changes: 10 additions & 3 deletions sklearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):

Parameters
----------
parameter_constraints : dict
A dictionary `param_name: list of constraints`. A parameter is valid if it
satisfies one of the constraints from the list. Constraints can be:
parameter_constraints : dict or {"no_validation"}
If "no_validation", validation is skipped for this parameter.

If a dict, it must be a dictionary `param_name: list of constraints`.
A parameter is valid if it satisfies one of the constraints from the list.
Constraints can be:
- an Interval object, representing a continuous or discrete range of numbers
- the string "array-like"
- the string "sparse matrix"
Expand All @@ -47,6 +50,10 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):

for param_name, param_val in params.items():
constraints = parameter_constraints[param_name]

if constraints == "no_validation":
continue

constraints = [make_constraint(constraint) for constraint in constraints]

for constraint in constraints:
Expand Down
7 changes: 6 additions & 1 deletion sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4064,6 +4064,12 @@ def check_param_validation(name, estimator_orig):
methods = [method for method in fit_methods if hasattr(estimator_orig, method)]

for param_name in estimator_params:
constraints = estimator_orig._parameter_constraints[param_name]

if constraints == "no_validation":
# This parameter is not validated
continue

match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead."
err_msg = (
f"{name} does not raise an informative error message when the "
Expand All @@ -4082,7 +4088,6 @@ def check_param_validation(name, estimator_orig):
# Then, for constraints that are more than a type constraint, check that the
# error is raised if param does match a valid type but does not match any valid
# value for this type.
constraints = estimator_orig._parameter_constraints[param_name]
constraints = [make_constraint(constraint) for constraint in constraints]

for constraint in constraints:
Expand Down
19 changes: 19 additions & 0 deletions sklearn/utils/tests/test_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,22 @@ def f(param):
FutureWarning, match="Passing an int for a boolean parameter is deprecated"
):
f(1)


def test_no_validation():
    """Ensure a parameter marked "no_validation" is exempt from validation."""

    class Unchecked:
        """Arbitrary user-defined type passed as the unvalidated value."""

    constraints = {"param1": [int, None], "param2": "no_validation"}

    @validate_params(constraints)
    def f(param1=None, param2=None):
        return None

    # Regular constraints are still enforced: a str is neither int nor None.
    with pytest.raises(ValueError, match="The 'param1' parameter"):
        f(param1="wrong")

    # param2 opts out of validation, so the class object itself and an
    # instance of it must both go through without raising.
    for value in (Unchecked, Unchecked()):
        f(param2=value)