Skip to content

MAINT Param validation: apply skip nested validation to all functions #26495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 22, 2023

Conversation

jeremiedbb
Copy link
Member

@jeremiedbb jeremiedbb commented Jun 2, 2023

Follow-up of #25815
This PR sets prefer_skip_nested_validation to all validated functions.

There are situations where we don't want to skip inner validation (prefer_skip_nested_validation=False):

  • the user passes an unfitted estimator instance
  • the user passes a callable and the args for the callable as a dict (e.g. metric and metric_params). Note that if the user only passes a callable we want to skip inner validation because the args passed to the callable come from us and not from the user.
  • the function is just a wrapper around an estimator class and hence only performs partial validation.

Functions receiving cv objects can skip inner validation because cv objects are not validated yet. When we decide to validate cv objects we'll need to revisit this.

@jeremiedbb jeremiedbb added No Changelog Needed Validation related to input validation labels Jun 2, 2023
@jeremiedbb jeremiedbb added this to the 1.3 milestone Jun 2, 2023
@adrinjalali
Copy link
Member

Should we then make prefer_skip_nested_validation a required arg, so that we always set it and think about how to set it?

@jeremiedbb
Copy link
Member Author

That's what I did :) see the new signature:
def validate_params(parameter_constraints, *, prefer_skip_nested_validation):

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it'd be feasible for you to put a comment wherever prefer_skip_nested_validation=False? Cause it's almost always True

@@ -680,7 +681,8 @@ def _set_reach_dist(
"core_distances": [np.ndarray],
"ordering": [np.ndarray],
"eps": [Interval(Real, 0, None, closed="both")],
}
},
prefer_skip_nested_validation=True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this here True but compute_optics_graph has it as False?

Copy link
Member Author

@jeremiedbb jeremiedbb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adrinjalali here's a list of all places where I set it to False with the reason for each.

Copy link
Member Author

@jeremiedbb jeremiedbb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@adrinjalali here's a list of all places where I set it to False with the reason for each.

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the ones which are False and the reason is something other than them being an estimator wrapper, it'd be nice to have the reason commented in the code for future maintainers.

@@ -559,7 +560,8 @@ def _more_tags(self):
{
"X": ["array-like"],
"axis": [Options(Integral, {0, 1})],
}
},
prefer_skip_nested_validation=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this false but the other ones in this file True?

@@ -296,6 +296,7 @@ def test_function_param_validation(func_module):

PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minmax_scale and similar ones are not?

@@ -305,6 +304,7 @@ def test_function_param_validation(func_module):
("sklearn.decomposition.dict_learning", "sklearn.decomposition.DictionaryLearning"),
("sklearn.decomposition.fastica", "sklearn.decomposition.FastICA"),
("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
("sklearn.preprocessing.maxabs_scale", "sklearn.preprocessing.MaxAbsScaler"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should robust_scale also be here?

@glemaitre glemaitre merged commit 04575f5 into scikit-learn:main Jun 22, 2023
@glemaitre
Copy link
Member

Thanks @jeremiedbb.

@jeremiedbb jeremiedbb added the To backport PR merged in master that need a backport to a release branch defined based on the milestone. label Jun 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
No Changelog Needed To backport PR merged in master that need a backport to a release branch defined based on the milestone. Validation related to input validation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants