ENH allow to return train/test split indices in cross_validate #25659

Merged
12 changes: 6 additions & 6 deletions doc/modules/cross_validation.rst
@@ -208,8 +208,8 @@ two ways:
- It allows specifying multiple metrics for evaluation.

- It returns a dict containing fit-times, score-times
(and optionally training scores as well as fitted estimators) in
addition to the test score.
(and optionally training scores, fitted estimators, and train-test split indices)
in addition to the test score.

For single metric evaluation, where the scoring parameter is a string,
callable or None, the keys will be - ``['test_score', 'fit_time', 'score_time']``
@@ -220,10 +220,10 @@ following keys -

``return_train_score`` is set to ``False`` by default to save computation time.
To evaluate the scores on the training set as well you need to set it to
``True``.

You may also retain the estimator fitted on each training set by setting
``return_estimator=True``.
``True``. You may also retain the estimator fitted on each training set by
setting ``return_estimator=True``. Similarly, you may set
`return_indices=True` to retain the training and testing indices used to split
the dataset into train and test sets for each cv split.

The multiple metrics can be specified either as a list, tuple or set of
predefined scorer names::
9 changes: 8 additions & 1 deletion doc/whats_new/v1.3.rst
@@ -127,7 +127,7 @@ Changelog
- |Efficiency| :class:`decomposition.MiniBatchDictionaryLearning` and
:class:`decomposition.MiniBatchSparsePCA` are now faster for small batch sizes by
avoiding duplicate validations.
:pr:`25490` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
:pr:`25490` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.ensemble`
.......................
@@ -215,6 +215,13 @@ Changelog
- |API| The `eps` parameter of the :func:`log_loss` has been deprecated and will be
removed in 1.5. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.

:mod:`sklearn.model_selection`
..............................

- |Enhancement| :func:`model_selection.cross_validate` accepts a new parameter
`return_indices` to return the train-test indices of each cv split.
:pr:`25659` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.naive_bayes`
..........................

22 changes: 21 additions & 1 deletion sklearn/model_selection/_validation.py
@@ -60,6 +60,7 @@ def cross_validate(
pre_dispatch="2*n_jobs",
return_train_score=False,
return_estimator=False,
return_indices=False,
error_score=np.nan,
):
"""Evaluate metric(s) by cross-validation and also record fit/score times.
@@ -169,6 +170,11 @@

.. versionadded:: 0.20

return_indices : bool, default=False
Whether to return the train-test indices selected for each split.

.. versionadded:: 1.3

error_score : 'raise' or numeric, default=np.nan
Value to assign to the score if an error occurs in estimator fitting.
If set to 'raise', the error is raised.
@@ -207,6 +213,11 @@
The estimator objects for each cv split.
This is available only if ``return_estimator`` parameter
is set to ``True``.
``indices``
The train/test positional indices for each cv split. A dictionary
is returned where the keys are either `"train"` or `"test"`
and the associated values are a list of integer-dtyped NumPy
arrays with the indices. Available only if `return_indices=True`.

See Also
--------
@@ -260,6 +271,11 @@
else:
scorers = _check_multimetric_scoring(estimator, scoring)

indices = cv.split(X, y, groups)
if return_indices:
# materialize the indices since we need to store them in the returned dict
indices = list(indices)

# We clone the estimator to make sure that all the folds are
# independent, and that it is pickle-able.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch)
@@ -279,7 +295,7 @@
return_estimator=return_estimator,
error_score=error_score,
)
for train, test in cv.split(X, y, groups)
for train, test in indices
)

_warn_or_raise_about_fit_failures(results, error_score)
Expand All @@ -299,6 +315,10 @@ def cross_validate(
if return_estimator:
ret["estimator"] = results["estimator"]

if return_indices:
ret["indices"] = {}
ret["indices"]["train"], ret["indices"]["test"] = zip(*indices)

test_scores_dict = _normalize_score_results(results["test_scores"])
if return_train_score:
train_scores_dict = _normalize_score_results(results["train_scores"])
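As a side note on the implementation above, the `zip(*indices)` idiom transposes the materialized list of `(train, test)` index pairs into separate per-key tuples. A minimal pure-Python sketch (with made-up index lists standing in for the arrays yielded by `cv.split`):

```python
# Each element of `indices` is one cv split's (train, test) pair,
# mimicking what a materialized `cv.split(X, y, groups)` yields.
indices = [
    ([2, 3, 4, 5], [0, 1]),  # split 0: (train, test)
    ([0, 1, 4, 5], [2, 3]),  # split 1
    ([0, 1, 2, 3], [4, 5]),  # split 2
]

ret = {"indices": {}}
# zip(*pairs) transposes: all train lists first, all test lists second
ret["indices"]["train"], ret["indices"]["test"] = zip(*indices)

print(ret["indices"]["train"])  # → ([2, 3, 4, 5], [0, 1, 4, 5], [0, 1, 2, 3])
print(ret["indices"]["test"])   # → ([0, 1], [2, 3], [4, 5])
```

This is also why the generator must be materialized with `list(indices)` first: the splits are consumed once by the parallel fit loop and then again here to populate the returned dict.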
27 changes: 26 additions & 1 deletion sklearn/model_selection/tests/test_validation.py
@@ -63,7 +63,7 @@

from sklearn.impute import SimpleImputer

from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import LabelEncoder, scale
from sklearn.pipeline import Pipeline

from io import StringIO
@@ -2399,3 +2399,28 @@ def test_learning_curve_partial_fit_regressors():

# Does not error
learning_curve(MLPRegressor(), X, y, exploit_incremental_learning=True, cv=2)


def test_cross_validate_return_indices(global_random_seed):
"""Check the behaviour of `return_indices` in `cross_validate`."""
X, y = load_iris(return_X_y=True)
X = scale(X) # scale features for better convergence
estimator = LogisticRegression()

cv = KFold(n_splits=3, shuffle=True, random_state=global_random_seed)
cv_results = cross_validate(estimator, X, y, cv=cv, n_jobs=2, return_indices=False)
assert "indices" not in cv_results

cv_results = cross_validate(estimator, X, y, cv=cv, n_jobs=2, return_indices=True)
assert "indices" in cv_results
train_indices = cv_results["indices"]["train"]
test_indices = cv_results["indices"]["test"]
assert len(train_indices) == cv.n_splits
assert len(test_indices) == cv.n_splits

assert_array_equal([indices.size for indices in train_indices], 100)
assert_array_equal([indices.size for indices in test_indices], 50)

for split_idx, (expected_train_idx, expected_test_idx) in enumerate(cv.split(X, y)):
assert_array_equal(train_indices[split_idx], expected_train_idx)
assert_array_equal(test_indices[split_idx], expected_test_idx)