diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index fbc00f3069e51..113a015c2bbca 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -1706,12 +1706,19 @@ def _validate_shuffle_split(n_samples, test_size, train_size): class PredefinedSplit(BaseCrossValidator): """Predefined split cross-validator - Splits the data into training/test set folds according to a predefined - scheme. Each sample can be assigned to at most one test set fold, as - specified by the user through the ``test_fold`` parameter. + Provides train/test indices to split data into train/test sets using a + predefined scheme specified by the user with the ``test_fold`` parameter. Read more in the :ref:`User Guide <cross_validation>`. + Parameters + ---------- + test_fold : array-like, shape (n_samples,) + The entry ``test_fold[i]`` represents the index of the test set that + sample ``i`` belongs to. It is possible to exclude sample ``i`` from + any test set (i.e. include sample ``i`` in every training set) by + setting ``test_fold[i]`` equal to -1. + Examples -------- >>> from sklearn.model_selection import PredefinedSplit