From e42ba2caed2b1a75f7340f7366bcbff68b65ba8f Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 13 Jun 2022 15:14:25 +0200 Subject: [PATCH 1/4] allow validation skip --- sklearn/utils/_param_validation.py | 5 +++++ sklearn/utils/estimator_checks.py | 7 ++++++- sklearn/utils/tests/test_param_validation.py | 19 +++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 1d34886267e6c..c0672d792ec79 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -28,6 +28,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): - None, meaning that None is a valid value for the parameter - any type, meaning that any instance of this type is valid - a StrOptions object, representing a set of strings + - the string "no validation", skipping validation of the parameter params : dict A dictionary `param_name: param_value`. The parameters to validate against the @@ -44,6 +45,10 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): for param_name, param_val in params.items(): constraints = parameter_constraints[param_name] + + if constraints == ["no validation"]: + continue + constraints = [make_constraint(constraint) for constraint in constraints] for constraint in constraints: diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 80e242e166a9e..056a1cd56349b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4064,6 +4064,12 @@ def check_param_validation(name, estimator_orig): methods = [method for method in fit_methods if hasattr(estimator_orig, method)] for param_name in estimator_params: + constraints = estimator_orig._parameter_constraints[param_name] + + if constraints == ["no validation"]: + # This parameter is not validated + continue + match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead." err_msg = ( f"{name} does not raise an informative error message when the " @@ -4082,7 +4088,6 @@ def check_param_validation(name, estimator_orig): # Then, for constraints that are more than a type constraint, check that the # error is raised if param does match a valid type but does not match any valid # value for this type. - constraints = estimator_orig._parameter_constraints[param_name] constraints = [make_constraint(constraint) for constraint in constraints] for constraint in constraints: diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index 35da751b6fa52..57ff8b1862d9c 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -464,3 +464,22 @@ def f(param): err_msg = str(exc_info.value) assert "auto" in err_msg assert "warn" not in err_msg + + +def test_no_validation(): + """Check that validation can be skipped for a parameter.""" + + @validate_params({"param1": [int, None], "param2": ["no validation"]}) + def f(param1=None, param2=None): + pass + + # param1 is validated + with pytest.raises(ValueError, match="The 'param1' parameter"): + f(param1="wrong") + + # param2 is not validated: any type is valid. + class SomeType: + pass + + f(param2=SomeType) + f(param2=SomeType()) From 5c27d3e28fa7e878ecb699a51b5b2974fc64e730 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Mon, 13 Jun 2022 15:48:19 +0200 Subject: [PATCH 2/4] lint --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 056a1cd56349b..b80b97afbc3a6 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4065,7 +4065,7 @@ def check_param_validation(name, estimator_orig): for param_name in estimator_params: constraints = estimator_orig._parameter_constraints[param_name] - + if constraints == ["no validation"]: # This parameter is not validated continue From cfab1e9c91abf7a5871c0e80995a4b5f05c56e69 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 21 Jun 2022 15:36:57 +0200 Subject: [PATCH 3/4] no_validation --- sklearn/utils/_param_validation.py | 6 +++--- sklearn/utils/tests/test_param_validation.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 7810821454af8..b42fe85eb8c16 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -25,13 +25,13 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): - an Interval object, representing a continuous or discrete range of numbers - the string "array-like" - the string "sparse matrix" - - the string "random state" + - the string "random_state" - callable - None, meaning that None is a valid value for the parameter - any type, meaning that any instance of this type is valid - a StrOptions object, representing a set of strings - the string "boolean" - - the string "no validation", skipping validation of the parameter + - the string "no_validation", skipping validation of the parameter params : dict A dictionary `param_name: param_value`. The parameters to validate against the @@ -49,7 +49,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): for param_name, param_val in params.items(): constraints = parameter_constraints[param_name] - if constraints == ["no validation"]: + if constraints == ["no_validation"]: continue constraints = [make_constraint(constraint) for constraint in constraints] diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index e467aec129d99..86ec5f54a16fa 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -527,7 +527,7 @@ def f(param): def test_no_validation(): """Check that validation can be skipped for a parameter.""" - @validate_params({"param1": [int, None], "param2": ["no validation"]}) + @validate_params({"param1": [int, None], "param2": ["no_validation"]}) def f(param1=None, param2=None): pass From 7b1dc4232dca41dce11efdf443d9ebc97e41bf3e Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Wed, 22 Jun 2022 15:24:33 +0200 Subject: [PATCH 4/4] string "no_validation" --- sklearn/tests/test_public_functions.py | 23 ++++++++++++++------ sklearn/utils/_param_validation.py | 12 +++++----- sklearn/utils/estimator_checks.py | 2 +- sklearn/utils/tests/test_param_validation.py | 2 +- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index ec25afeb6b5f8..20ab0665eec88 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -34,10 +34,14 @@ def test_function_param_validation(func_module): required_params = [ p.name for p in func_sig.parameters.values() if p.default is p.empty ] - required_params = { - p: generate_valid_param(make_constraint(parameter_constraints[p][0])) - for p in required_params - } + valid_required_params = {} + for param_name in required_params: + if parameter_constraints[param_name] == "no_validation": + valid_required_params[param_name] = 1 + else: + valid_required_params[param_name] = generate_valid_param( + make_constraint(parameter_constraints[param_name][0]) + ) # check that there is a constraint for each parameter if func_params: @@ -51,18 +55,23 @@ def test_function_param_validation(func_module): param_with_bad_type = type("BadType", (), {})() for param_name in func_params: + constraints = parameter_constraints[param_name] + + if constraints == "no_validation": + # This parameter is not validated + continue + match = ( rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead." ) # First, check that the error is raised if param doesn't match any valid type. with pytest.raises(ValueError, match=match): - func(**{**required_params, param_name: param_with_bad_type}) + func(**{**valid_required_params, param_name: param_with_bad_type}) # Then, for constraints that are more than a type constraint, check that the # error is raised if param does match a valid type but does not match any valid # value for this type. - constraints = parameter_constraints[param_name] constraints = [make_constraint(constraint) for constraint in constraints] for constraint in constraints: @@ -72,4 +81,4 @@ def test_function_param_validation(func_module): continue with pytest.raises(ValueError, match=match): - func(**{**required_params, param_name: bad_value}) + func(**{**valid_required_params, param_name: bad_value}) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index b42fe85eb8c16..880343d61d8f3 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -19,9 +19,12 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): Parameters ---------- - parameter_constraints : dict - A dictionary `param_name: list of constraints`. A parameter is valid if it - satisfies one of the constraints from the list. Constraints can be: + parameter_constraints : dict or {"no_validation"} + If "no_validation", validation is skipped for this parameter. + + If a dict, it must be a dictionary `param_name: list of constraints`. + A parameter is valid if it satisfies one of the constraints from the list. + Constraints can be: - an Interval object, representing a continuous or discrete range of numbers - the string "array-like" - the string "sparse matrix" @@ -31,7 +34,6 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): - any type, meaning that any instance of this type is valid - a StrOptions object, representing a set of strings - the string "boolean" - - the string "no_validation", skipping validation of the parameter params : dict A dictionary `param_name: param_value`. The parameters to validate against the @@ -49,7 +51,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): for param_name, param_val in params.items(): constraints = parameter_constraints[param_name] - if constraints == ["no_validation"]: + if constraints == "no_validation": continue constraints = [make_constraint(constraint) for constraint in constraints] diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index b80b97afbc3a6..63be54d2db03b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4066,7 +4066,7 @@ def check_param_validation(name, estimator_orig): for param_name in estimator_params: constraints = estimator_orig._parameter_constraints[param_name] - if constraints == ["no validation"]: + if constraints == "no_validation": # This parameter is not validated continue diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index 86ec5f54a16fa..a8b41e0fb6b5e 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -527,7 +527,7 @@ def f(param): def test_no_validation(): """Check that validation can be skipped for a parameter.""" - @validate_params({"param1": [int, None], "param2": ["no_validation"]}) + @validate_params({"param1": [int, None], "param2": "no_validation"}) def f(param1=None, param2=None): pass