model_selection.cross_validate doesn't pass groups argument to estimator #20349

Closed
thomsentner opened this issue Jun 24, 2021 · 2 comments · Fixed by #26896

@thomsentner

thomsentner commented Jun 24, 2021

Describe the workflow you want to enable

Nested cross-validation is currently impossible with a grouped k-fold iterator in the inner loop. The workflow that sklearn currently recommends puts model_selection.cross_val_score or model_selection.cross_validate in the outer loop and model_selection.GridSearchCV in the inner loop. However, model_selection.cross_validate uses the groups parameter only for its own cv instance and does not forward it to the estimator; this behavior also appears to be documented.

Describe your proposed solution

Pass the groups parameter from model_selection.cross_validate to the estimator through model_selection._validation._fit_and_score. The code changes appear to be minimal: passing the groups parameter along in three lines of code would be sufficient.

Additional context

sklearn's nested cross validation documentation actually assumes this functionality to be in place already, as GroupKFold is suggested as a compatible cv instance.
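
Until the groups parameter is forwarded, one possible workaround is to write the outer loop by hand and pass the sliced groups to the inner search directly. The following is only a minimal sketch under stated assumptions: the synthetic data, the Ridge estimator, and the alpha grid are illustrative choices and not part of the original report. It relies on GridSearchCV.fit accepting a groups keyword, which it forwards to its own cv splitter.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, GroupKFold

# Illustrative synthetic data (assumption): 6 groups of 10 samples each.
rng = np.random.default_rng(0)
X = rng.normal(size=(60, 3))
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(scale=0.1, size=60)
groups = np.repeat(np.arange(6), 10)

outer_cv = GroupKFold(n_splits=3)
inner_cv = GroupKFold(n_splits=2)

outer_scores = []
for train_idx, test_idx in outer_cv.split(X, y, groups=groups):
    # Inner loop: hyperparameter search with a grouped splitter.
    search = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0, 10.0]}, cv=inner_cv)
    # Slice groups with the outer training indices so the inner
    # GroupKFold sees the group labels of exactly the rows it splits.
    search.fit(X[train_idx], y[train_idx], groups=groups[train_idx])
    # Score the refitted best estimator on the held-out outer groups.
    outer_scores.append(search.score(X[test_idx], y[test_idx]))

print(outer_scores)
```

This gives up the convenience of cross_validate (parallelism, multiple scorers, timing), which is why forwarding groups inside the library would still be preferable.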

@jnothman
Member

jnothman commented Jun 24, 2021 via email

@pavelkomarov

pavelkomarov commented Jul 21, 2021

I also want this. I'm faced with a situation where I'm trying to train an LSTM on several time series. I want to reset the model state between series, so the model is making fresh predictions on each, but because GroupKFold doesn't pass through which points belong to which series, I'm having to make my own data divisions from scratch.

I just had this realization while writing a comment on a different thread and came back: If I've got X, y, groups as numpy arrays, and .split returns me indices train_ndx, val_ndx, then I can see which groups things belong to by indexing groups[train_ndx] and groups[val_ndx].
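
The indexing idea above can be sketched roughly as follows. The toy arrays and the series labels "a" through "d" are made up for illustration; only GroupKFold.split and plain numpy indexing are used.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Toy data (assumption): 4 series of 3 points each.
X = np.arange(24, dtype=float).reshape(12, 2)
y = np.arange(12, dtype=float)
groups = np.repeat(["a", "b", "c", "d"], 3)

fold_groups = []
for train_ndx, val_ndx in GroupKFold(n_splits=2).split(X, y, groups):
    # Recover which series each split contains by indexing groups.
    train_series = np.unique(groups[train_ndx])
    val_series = np.unique(groups[val_ndx])
    fold_groups.append((set(train_series.tolist()), set(val_series.tolist())))

    # Walk the training set one series at a time, e.g. to reset
    # a recurrent model's state between series.
    for series_id in train_series:
        series_mask = groups[train_ndx] == series_id
        X_series = X[train_ndx][series_mask]
        y_series = y[train_ndx][series_mask]
        # ... fit or update the model on (X_series, y_series) here ...

print(fold_groups)
```

This only helps when the splitting loop is written by hand; it does not make the group labels available inside an estimator called by cross_validate.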
