API Proposal: Generalized Cross-Validation and Early Stopping #1626

Closed
@amueller

Description

This is a proposal to resolve two API issues in sklearn:

  • Generalized Cross Validation
  • Early stopping

Why should we care about that?

By generalized cross-validation I mean finding the best setting of some parameter without refitting the entire model. This is currently implemented for RFE and some linear models via an EstimatorCV. These don't work well together with GridSearchCV, which may be required in a Pipeline or when more than one parameter needs to be found.
Also, similar functionality would be great for other models like GradientBoosting (for n_estimators) and for all tree-based methods (for max_depth).
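
To make the contrast concrete, here is a minimal sketch of the status quo using the existing Lasso / LassoCV / GridSearchCV interfaces (import paths differ between sklearn versions; adjust as needed):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso, LassoCV
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search in older versions

X, y = make_regression(n_samples=200, n_features=20, random_state=0)
alphas = np.logspace(-3, 1, 30)

# Efficient: LassoCV computes the whole regularization path per CV split,
# so the model is not refit from scratch for every alpha.
lasso_cv = LassoCV(alphas=alphas, cv=5).fit(X, y)

# Generic but wasteful: GridSearchCV refits a fresh Lasso for every
# (alpha, split) pair and cannot reuse the path computation.
grid = GridSearchCV(Lasso(), {'alpha': list(alphas)}, cv=5).fit(X, y)
```

The EstimatorCV variant is fast, but it cannot be combined with other grid parameters or used transparently inside a Pipeline.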

By early stopping I mean saving computation when more computation doesn't improve the result. We don't have that yet, but it would be a great (maybe even necessary) feature for SGD-based methods and bagging methods (random forests and extra trees). Note that early stopping needs a validation set to evaluate the model.
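
Right now, something like the following has to be written by hand for every model that supports partial_fit; a rough sketch for SGDClassifier (the tolerance and the number of epochs here are arbitrary):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

clf = SGDClassifier(random_state=0)
best_score, tol = -np.inf, 1e-3
for epoch in range(100):
    clf.partial_fit(X_train, y_train, classes=np.unique(y))
    score = clf.score(X_val, y_val)
    if score < best_score + tol:
        break  # validation score stopped improving enough: stop early
    best_score = score
```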

How can we solve that?

Let's start with generalized cross-validation.
We need it to work together with GridSearchCV.
This will definitely require changes in both GridSearchCV and the estimators.
My idea:

  1. Give the estimator an iterable of values you want to try, e.g. max_depth=range(1, 10). During fit, the estimator fits in a way that lets it produce predictions for all of these values.
  2. When predict is called, the estimator returns a dict whose keys are the parameter values and whose values are the predictions for those parameters (we could also add a new predict_all function, but I'm not sure about that). See the sketch after this list.
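
A hypothetical sketch of what this could look like for a tree; nothing here exists yet, and the dict-valued predict is exactly the open question above:

```python
# Hypothetical: DecisionTreeClassifier does not accept an iterable max_depth today.
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

tree = DecisionTreeClassifier(max_depth=range(1, 10))  # iterable of candidate values
tree.fit(X_train, y_train)            # one fit covers all candidate depths

predictions = tree.predict(X_test)    # dict: {1: array([...]), ..., 9: array([...])}
for depth, y_pred in predictions.items():
    print(depth, (y_pred == y_test).mean())
```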

GridSearchCV could then simply incorporate these values into the grid-search result.
For that, GridSearchCV needs to be able to ask the estimator which parameters it can do generalized CV over, and just pass on the list of values it got for those parameters.
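
One possible shape of that handshake, purely as a sketch (the attribute name is made up):

```python
# Hypothetical: the estimator advertises which parameters it can sweep internally.
param_grid = {'max_depth': range(1, 10), 'min_samples_leaf': [1, 5, 10]}
gcv_params = getattr(estimator, 'generalized_cv_params_', [])  # e.g. ['max_depth']

# Those parameters are handed straight to the estimator; only the remaining
# ones are expanded into the usual refit-per-point grid.
inner = {p: v for p, v in param_grid.items() if p in gcv_params}
outer = {p: v for p, v in param_grid.items() if p not in gcv_params}
```

GridSearchCV would then merge the per-value predictions (or scores) returned by the estimator into its result table alongside the ordinary grid points.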

Now on to early stopping. The reason I want to treat the two problems as one is that early stopping is basically a lazy form of generalized cross-validation.
So

  1. You would provide the estimator with an iterable, e.g. n_iter=range(1, 100, 10).
  2. The estimator fits for all of these values (implemented as above), but for each setting it also evaluates on a validation set, and if there is only a small improvement, training stops; see the sketch after this list.
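
Roughly, the fit loop inside such an estimator might look like this sketch (all names are placeholders; how the validation set and the tolerance get into the estimator is discussed next):

```python
# Hypothetical internal loop; _fit_up_to, _score, X_val, y_val and tol are placeholders.
def fit(self, X, y):
    best = -float('inf')
    self.validation_scores_ = {}
    for n in self.n_iter:                   # e.g. range(1, 100, 10)
        self._fit_up_to(X, y, n)            # continue training up to n iterations
        score = self._score(self.X_val, self.y_val)
        self.validation_scores_[n] = score  # kept around so GridSearchCV can reuse them
        if score < best + self.tol:
            break                           # improvement too small: stop training
        best = score
    return self
```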

I would provide the validation set either as a parameter to __init__ or to fit (not sure which). So it would be enough to add two parameters to the estimator: early_stopping_tolerance and early_stopping_set=None (if it is None, no early stopping).
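
With those two parameters, usage could look like this (hypothetical; SGDClassifier accepts neither an iterable n_iter nor these arguments today, and the data split is the one from the partial_fit sketch further up):

```python
from sklearn.linear_model import SGDClassifier

# Hypothetical usage of the two proposed parameters.
sgd = SGDClassifier(n_iter=range(1, 100, 10),
                    early_stopping_set=(X_val, y_val),   # None => no early stopping
                    early_stopping_tolerance=1e-3)
sgd.fit(X_train, y_train)
predictions = sgd.predict(X_val)  # dict keyed by the n_iter values actually reached
```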

There are two choices that I made here:

  1. Provide a separate validation set rather than generating one from the training set on the fly.
    This is again so that this can be used inside GridSearchCV, and it doesn't really add much overhead if the user doesn't use GridSearchCV (why would they do that anyway?).
    It is also very explicit and gives the user a lot of control.
  2. Provide the parameter settings as an iterable, not a maximum. The reason is that you probably don't want to evaluate on the validation set every iteration, but maybe every k iterations. What's k? Another parameter? I feel like an iterable is a good way to specify this. Also, it allows for a unified interface with the generalized cross-validation case.

Restrictions

In pipelines, this will only work for the last estimator. So I'm not sure how to do this with RFE, for example.

Do we really need this?

The changes I proposed above are quite big in some sense, but I think the two issues need to be resolved. If you have any better idea, feel free to explain it ;)
