diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index ec25afeb6b5f8..20ab0665eec88 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -34,10 +34,14 @@ def test_function_param_validation(func_module): required_params = [ p.name for p in func_sig.parameters.values() if p.default is p.empty ] - required_params = { - p: generate_valid_param(make_constraint(parameter_constraints[p][0])) - for p in required_params - } + valid_required_params = {} + for param_name in required_params: + if parameter_constraints[param_name] == "no_validation": + valid_required_params[param_name] = 1 + else: + valid_required_params[param_name] = generate_valid_param( + make_constraint(parameter_constraints[param_name][0]) + ) # check that there is a constraint for each parameter if func_params: @@ -51,18 +55,23 @@ def test_function_param_validation(func_module): param_with_bad_type = type("BadType", (), {})() for param_name in func_params: + constraints = parameter_constraints[param_name] + + if constraints == "no_validation": + # This parameter is not validated + continue + match = ( rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead." ) # First, check that the error is raised if param doesn't match any valid type. with pytest.raises(ValueError, match=match): - func(**{**required_params, param_name: param_with_bad_type}) + func(**{**valid_required_params, param_name: param_with_bad_type}) # Then, for constraints that are more than a type constraint, check that the # error is raised if param does match a valid type but does not match any valid # value for this type. - constraints = parameter_constraints[param_name] constraints = [make_constraint(constraint) for constraint in constraints] for constraint in constraints: @@ -72,4 +81,4 @@ def test_function_param_validation(func_module): continue with pytest.raises(ValueError, match=match): - func(**{**required_params, param_name: bad_value}) + func(**{**valid_required_params, param_name: bad_value}) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 3832f6cbc4d20..880343d61d8f3 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -19,9 +19,12 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): Parameters ---------- - parameter_constraints : dict - A dictionary `param_name: list of constraints`. A parameter is valid if it - satisfies one of the constraints from the list. Constraints can be: + parameter_constraints : dict or {"no_validation"} + If "no_validation", validation is skipped for this parameter. + + If a dict, it must be a dictionary `param_name: list of constraints`. + A parameter is valid if it satisfies one of the constraints from the list. + Constraints can be: - an Interval object, representing a continuous or discrete range of numbers - the string "array-like" - the string "sparse matrix" @@ -47,6 +50,10 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): for param_name, param_val in params.items(): constraints = parameter_constraints[param_name] + + if constraints == "no_validation": + continue + constraints = [make_constraint(constraint) for constraint in constraints] for constraint in constraints: diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 80e242e166a9e..63be54d2db03b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4064,6 +4064,12 @@ def check_param_validation(name, estimator_orig): methods = [method for method in fit_methods if hasattr(estimator_orig, method)] for param_name in estimator_params: + constraints = estimator_orig._parameter_constraints[param_name] + + if constraints == "no_validation": + # This parameter is not validated + continue + match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead." err_msg = ( f"{name} does not raise an informative error message when the " @@ -4082,7 +4088,6 @@ def check_param_validation(name, estimator_orig): # Then, for constraints that are more than a type constraint, check that the # error is raised if param does match a valid type but does not match any valid # value for this type. - constraints = estimator_orig._parameter_constraints[param_name] constraints = [make_constraint(constraint) for constraint in constraints] for constraint in constraints: diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index 9675ce3c217f5..a8b41e0fb6b5e 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -522,3 +522,22 @@ def f(param): FutureWarning, match="Passing an int for a boolean parameter is deprecated" ): f(1) + + +def test_no_validation(): + """Check that validation can be skipped for a parameter.""" + + @validate_params({"param1": [int, None], "param2": "no_validation"}) + def f(param1=None, param2=None): + pass + + # param1 is validated + with pytest.raises(ValueError, match="The 'param1' parameter"): + f(param1="wrong") + + # param2 is not validated: any type is valid. + class SomeType: + pass + + f(param2=SomeType) + f(param2=SomeType())