ENH Generally avoid nested param validation #25815


Merged
17 commits merged into scikit-learn:main on May 24, 2023

Conversation

@jeremiedbb (Member) commented Mar 10, 2023

Alternative to, and closes, #25493.

It happens that estimators or functions call other public estimators or functions, sometimes in a loop, and parameter validation is performed on each of these calls. Param validation is cheap but definitely not free, and this is overhead we want to avoid.

In addition, once the first validation is done, we make sure that internally we pass the appropriate parameters to inner estimators or functions, so these nested validations are useless.

In #25493, I proposed to add an option to the config_context context manager to skip parameter validation locally. I think it will be cumbersome to add this everywhere necessary. And sometimes it's really not straightforward to see that an estimator calls a public function at some point. For instance MiniBatchDictionaryLearning -> _minibatch_step -> _sparse_encode -> _sparse_encode_precomputed -> Lasso.

This is why this PR proposes to go one step further and introduce a new decorator to decorate the fit methods of all estimators.

@_fit_context()
def fit(self, X, y):
    ...

where _fit_context does the param validation of the estimator and returns fit with further param validation disabled.
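
To make the idea concrete, here is a minimal sketch of what such a decorator could look like. It assumes the skip_parameter_validation global config option and the estimator's _validate_params() helper; this is an illustration of the mechanism, not the exact implementation in this PR:

import functools

from sklearn import config_context, get_config


def _fit_context(*, skip_nested_validation=True):
    def decorator(fit_method):
        @functools.wraps(fit_method)
        def wrapper(estimator, *args, **kwargs):
            global_skip = get_config()["skip_parameter_validation"]
            if not global_skip:
                # Always validate the parameters of the outermost estimator.
                estimator._validate_params()
            # Run fit with nested validation disabled when requested, or when
            # validation is globally disabled anyway.
            with config_context(
                skip_parameter_validation=skip_nested_validation or global_skip
            ):
                return fit_method(estimator, *args, **kwargs)

        return wrapper

    return decorator

Note that functools.wraps is what would keep fit?, auto-completion, and introspection intact, as discussed further down.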

There is no need to introduce a new decorator for the param validation of functions; we can reuse the validate_params decorator.

One thing to consider is that sometimes we do want to keep nested param validation: for functions that are just wrappers around classes (they delegate the validation to the underlying class) and for meta-estimators, for which we still need to validate the inner estimator when its fit method is called. This is why _fit_context has a skip_nested_validation parameter to choose whether or not to keep the nested validation.
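
For illustration, hypothetical usage with the sketch above (the estimator classes are made up):

from sklearn.base import BaseEstimator


class MyEstimator(BaseEstimator):
    # A leaf estimator: validate once, then skip all nested validation.
    @_fit_context()
    def fit(self, X, y=None):
        ...
        return self


class MyMetaEstimator(BaseEstimator):
    def __init__(self, estimator=None):
        self.estimator = estimator

    # A meta-estimator: keep nested validation so that the inner estimator's
    # parameters are still checked when its own fit method is called.
    @_fit_context(skip_nested_validation=False)
    def fit(self, X, y=None):
        ...
        return self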

Side note: this PR only makes the switch for a single estimator, to showcase the feature and ease the review. Doing it for all estimators will come in a follow-up PR.

Side note 2: such a decorator could be useful for other applications, for instance to properly tear down a callback in case the estimator's fit is interrupted.

@jeremiedbb jeremiedbb added the Validation related to input validation label Mar 10, 2023
@jeremiedbb jeremiedbb changed the title MAINT Generally avoid nested param validation ENH Generally avoid nested param validation Mar 10, 2023
@glemaitre glemaitre self-requested a review March 13, 2023 14:22
@glemaitre (Member) left a comment

a small first round

@glemaitre (Member) left a comment

So I am fine with the current implementation. Just wondering if it would be time to add a section about parameter validation to the develop.rst page. Of course, it should be done in another PR.

@jeremiedbb (Member, Author)

Regarding concerns raised in some in-person discussions about auto-completion, discoverability, and tracebacks:

  • In IPython or Jupyter, getting information about fit, i.e. fit? or fit??, gives the same results as before and leaves no trace of the decorator.
  • Auto-completion works the same as well in IPython or Jupyter. I also checked VS Code, but don't know about other editors.
  • In case of exception the traceback has an additional line:
    Traceback (most recent call last):
    File "/home/jeremie/R/sklearn/scikit-learn-2/script.py", line 5, in <module>
      dl.fit(X)
    File "/home/jeremie/R/sklearn/scikit-learn-2/sklearn/base.py", line 1119, in wrapper   # this is
      return fit_method(estimator, *args, **kwargs)                                        # additional
    File "/home/jeremie/R/sklearn/scikit-learn-2/sklearn/decomposition/_dict_learning.py", line 2425, in fit
      self._minibatch_step(X_train[batch], dictionary, self._random_state, i)
    File "/home/jeremie/R/sklearn/scikit-learn-2/sklearn/decomposition/_dict_learning.py", line 2234, in _minibatch_step
      raise ValueError("A somewhat informative error message")
    ValueError: A somewhat informative error message
    I think it's acceptable and doesn't hurt the readability of the traceback too much.

Let me know if there are more things to test that I haven't thought about.

@jeremiedbb (Member, Author)

I noticed recently that the test suite duration has increased a lot, and I think param validation has its share of responsibility. With the new config option, we could add a pytest session setup that deactivates param validation for the whole run (we would manually re-activate it for the param-validation-specific tests), as sketched below.
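
Such a session-wide setup could look like this in a conftest.py (a hypothetical sketch, assuming the new skip_parameter_validation config option):

import pytest

import sklearn


@pytest.fixture(autouse=True, scope="session")
def disable_param_validation():
    # Deactivate param validation for the whole test run; the tests dedicated
    # to param validation would re-activate it locally via config_context.
    sklearn.set_config(skip_parameter_validation=True)
    yield
    sklearn.set_config(skip_parameter_validation=False)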

@jeremiedbb (Member, Author)

@thomasjpfan I'd like to have your opinion on this if you have some time available.

@thomasjpfan (Member) left a comment

Thanks for the PR!

sklearn/base.py Outdated
@@ -1087,3 +1088,33 @@ def is_outlier_detector(estimator):
True if estimator is an outlier detector and False otherwise.
"""
return getattr(estimator, "_estimator_type", None) == "outlier_detector"


def _fit_context(*, skip_nested_validation=True):
@thomasjpfan (Member)

_fit_context does not describe what the decorator does. Maybe:

Suggested change:
- def _fit_context(*, skip_nested_validation=True):
+ def _validate_param_context(*, prefer_skip_nested_validation):

I'm open to better names.

Also, I think prefer_skip_nested_validation should be passed in all the time. It makes it easier to understand the intention from the call site. (See other comment about prefer_skip_nested_validation.)

@jeremiedbb (Member, Author)

I used a generic name on purpose. In the future we might want to run fit within another context manager; for instance, callbacks need fit to be run in a try/finally to get a proper tear-down.

> Also, I think prefer_skip_nested_validation should be passed in all the time.

I'm OK with passing it all the time. My original intention was to skip nested validation by default (without passing it explicitly), because this is almost always what we want. This way, when we explicitly pass False, it's more visible that this case is different. What do you think?

@thomasjpfan (Member)

> My original intention was to skip nested validation by default (without passing it explicitly), because this is almost always what we want. This way, when we explicitly pass False, it's more visible that this case is different. What do you think?

Given that the function name is _fit_context, I do not think the default of "turning off nested validation" is obviously True. For scikit-learn, there is a sizable number of estimators that still want nested validation. Off the top of my head:

  • All meta-estimators (validating the estimator parameters)
  • All estimators that accept a scorer (validating the parameters for the metric)
  • All estimators that accept a splitter (validating the splitter parameters)

@jeremiedbb (Member, Author)

I made it a required param for _fit_context.

For functions, I left the default as True, otherwise I would need to set it for all currently decorated functions as part of this PR; I thought I could do that in a separate PR. Do you prefer making the change for all functions here?

@betatim (Member) commented Apr 4, 2023

A comment about tracebacks: I agree that this PR adds only one additional line/level, but set_output also adds one line. So we are now already at two, which is a significant fraction of the total length. I don't know if this makes me -1 or +1 on adding "just one more line", but I think we should think about this a bit to see if we can find an alternative solution. Maybe not because we need it for this PR, but because we will need it at some point in the future (assuming the number of decorators/wrappers will increase, not decrease, over time).

And while it seems like a small thing, it does already irritate me to have this "obscure" output wrapper thing in the traceback. It is like a fly in your soup: not a big deal and easy to deal with, but somehow it still interrupts your lunch :D

@lorentzenchr (Member)

I agree with @betatim. A good place to discuss the big picture of avoiding parameter validation is #21804.

@jeremiedbb jeremiedbb added this to the 1.3 milestone Apr 7, 2023
@jeremiedbb (Member, Author) commented Apr 7, 2023

I added it to the 1.3 milestone because I think we don't want to release with the perf regressions reported by @ogrisel

@lorentzenchr (Member)

Meanwhile, I read the comments of this and the former PR, in particular #25493 (review), and the approach seems fine to me. (It would be nice to have such discussions in a dedicated issue and not scattered over PRs.)
IIUC, a standard user should not need this; one main usage is for scikit-learn itself.

@glemaitre (Member) left a comment

LGTM.

@@ -154,6 +155,10 @@ def validate_params(parameter_constraints):
Note that the *args and **kwargs parameters are not validated and must not be
present in the parameter_constraints dictionary.

prefer_skip_nested_validation : bool, default=True
Member

Why not call this skip_nested_validation instead? The prefer prefix makes me think that you are only indicating a preference here, but it looks like parameter validation is turned off if you set this to True?

@thomasjpfan has a comment where he refers to a comment where he (I think) expressed a preference for adding the prefer_ prefix but I can't find it :-/

@jeremiedbb (Member, Author)

That was my first naming, but the prefer prefix was suggested here: #25815 (comment).

As I don't have a strong opinion on the naming, I used the suggested one, even though I agree it's not really a preference but more of a flag. I'll let you and @thomasjpfan decide 😄

@thomasjpfan (Member)

It's a preference because when prefer_skip_nested_validation == False and global_skip_validation == True, the nested validation is still skipped.

@thomasjpfan (Member) commented Apr 19, 2023

Concretely, here is a table that maps from global config & parameter to what I think is the expected behavior:

| global_skip_validation | prefer_skip_nested_validation | skip nested validation? |
|------------------------|-------------------------------|-------------------------|
| True                   | True                          | True                    |
| True                   | False                         | True                    |
| False                  | True                          | True                    |
| False                  | False                         | False                   |
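
In other words, the table reduces to a logical OR; a sketch of the decision rule (not the exact code):

def should_skip_nested_validation(global_skip_validation, prefer_skip_nested_validation):
    # The global flag always wins when True; otherwise the decorator's (or
    # function's) stated preference decides.
    return global_skip_validation or prefer_skip_nested_validation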

Member

Ah, I hadn't even thought this far yet. But it shows that your expectation and mine are different. I had assumed (without realising it isn't true) that the context manager always "wins". That means if globally skip_validation=True and we enter a context manager that sets skip_validation=False, then validation isn't skipped.

Did you discuss why it makes sense to implement it the way you did? If we stick with it, I think we should add a comment to the context manager to make it clear that all it can do is enable skipping; it can't overrule the global setting.

@jeremiedbb (Member, Author)

The table above reflects how I imagined it and the proposed implementation follows this behavior.

> I had assumed that the context manager always "wins"

skip_parameter_validation from config_context and prefer_skip_nested_validation from validate_params have very different meanings:

  • the former is public and is meant to allow third-party devs to use sklearn's public tools in computationally intensive sections without having to pay the overhead of the validation. It completely deactivates validation, though, and hence can lead to uninformative error messages.
  • the latter is private and is just there for performance considerations. Validation of the outermost function or estimator still happens, so there will always be an informative error message for a bad param. Once the public-facing interface is validated, there's no need to validate the inner calls to functions, because we know that internally we never pass bad parameters (this is guaranteed, as much as possible, by our test suite). So all nested validations are just no-ops with a sometimes significant overhead.
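
As a hedged illustration of the public use case described in the first bullet, a third-party caller could wrap a hot loop with the public option (using the real Lasso estimator; the data here is made up):

import numpy as np

from sklearn import config_context
from sklearn.linear_model import Lasso

X = np.random.rand(100, 5)
y = np.random.rand(100)

# Inside this block, all parameter validation is skipped, trading informative
# error messages for lower per-call overhead.
with config_context(skip_parameter_validation=True):
    for _ in range(1_000):
        Lasso(alpha=0.1).fit(X, y)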

@thomasjpfan (Member) left a comment

I'm concerned over the default prefer_skip_nested_validation=True for functions. If a function uses a scorer, then the scorer will skip its validation. If a function uses an estimator and the function did not validate all the estimator's parameters, then the estimator may be invalid.


@jeremiedbb (Member, Author)

> I'm concerned over the default prefer_skip_nested_validation=True for functions.

I made it False by default. I think we can even make it a required param so that we're sure we don't forget to set it appropriately for every function, but let's do that in a follow-up PR to not pollute the diff of this one.
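
For reference, here is roughly what a decorated function looks like once the parameter is passed explicitly (a sketch following the validate_params signature shown in the diff above; my_function itself is made up):

from numbers import Real

from sklearn.utils._param_validation import Interval, validate_params


@validate_params(
    {"alpha": [Interval(Real, 0, None, closed="left")]},
    # This function delegates work to other public estimators, so keep
    # their validation active.
    prefer_skip_nested_validation=False,
)
def my_function(alpha=1.0):
    ...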

@thomasjpfan (Member) left a comment

LGTM

@ogrisel (Member) left a comment

LGTM! Just a quick suggestion to make one of the tests easier to follow.

@ogrisel ogrisel enabled auto-merge (squash) May 24, 2023 14:55
@ogrisel ogrisel merged commit 1284767 into scikit-learn:main May 24, 2023
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023