Skip to content

Commit af1796e

Browse files
neerajgangwaramueller
authored andcommitted
[MRG+1] Repeated K-Fold and Repeated Stratified K-Fold (#8120)
* Add _RepeatedSplits and RepeatedKFold class * Add RepeatedStratifiedKFold and doc for repeated cvs * Change default value of n_repeats * Change input parameters of repeated cv constructor to n_splits, n_repeats, random_state * Generate random states in split function rather than store it beforehand * Doc changes, inheriting RepeatedKFold, RepeatedStratifiedKFold from _RepeatedSplits and other review changes * Remove blank line, put testcases for deterministic split in loop and add StopIteration check in testcase * Using rng directly as random_state param to create cv instance and added a check for cvargs * Fix pep8 warnings * Changing default values for n_splits and n_repeats and add entry in changelog * Adding name to the feature * Missing space
1 parent fad5f9b commit af1796e

File tree

6 files changed

+286
-0
lines changed

6 files changed

+286
-0
lines changed

doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ Splitter Classes
170170
model_selection.LeavePGroupsOut
171171
model_selection.LeaveOneOut
172172
model_selection.LeavePOut
173+
model_selection.RepeatedKFold
174+
model_selection.RepeatedStratifiedKFold
173175
model_selection.ShuffleSplit
174176
model_selection.GroupShuffleSplit
175177
model_selection.StratifiedShuffleSplit

doc/modules/cross_validation.rst

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,33 @@ Thus, one can create the training/test sets using numpy indexing::
263263
>>> X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]
264264

265265

266+
Repeated K-Fold
267+
---------------
268+
269+
:class:`RepeatedKFold` repeats K-Fold n times. It can be used when one
270+
requires to run :class:`KFold` n times, producing different splits in
271+
each repetition.
272+
273+
Example of 2-fold K-Fold repeated 2 times::
274+
275+
>>> import numpy as np
276+
>>> from sklearn.model_selection import RepeatedKFold
277+
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
278+
>>> random_state = 12883823
279+
>>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)
280+
>>> for train, test in rkf.split(X):
281+
... print("%s %s" % (train, test))
282+
...
283+
[2 3] [0 1]
284+
[0 1] [2 3]
285+
[0 2] [1 3]
286+
[1 3] [0 2]
287+
288+
289+
Similarly, :class:`RepeatedStratifiedKFold` repeats Stratified K-Fold n times
290+
with different randomization in each repetition.
291+
292+
266293
Leave One Out (LOO)
267294
-------------------
268295

@@ -409,6 +436,10 @@ two slightly unbalanced classes::
409436
[0 1 3 4 5 8 9] [2 6 7]
410437
[0 1 2 4 5 6 7] [3 8 9]
411438

439+
:class:`RepeatedStratifiedKFold` can be used to repeat Stratified K-Fold n times
440+
with different randomization in each repetition.
441+
442+
412443
Stratified Shuffle Split
413444
------------------------
414445

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ New features
4141
Kullback-Leibler divergence and the Itakura-Saito divergence.
4242
By `Tom Dupre la Tour`_.
4343

44+
- Added the :class:`sklearn.model_selection.RepeatedKFold` and
45+
:class:`sklearn.model_selection.RepeatedStratifiedKFold`.
46+
:issue:`8120` by `Neeraj Gangwar`_.
47+
4448
- Added :func:`metrics.mean_squared_log_error`, which computes
4549
the mean square error of the logarithmic transformation of targets,
4650
particularly useful for targets with an exponential trend.
@@ -5004,3 +5008,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
50045008
.. _Vincent Pham: https://github.com/vincentpham1991
50055009

50065010
.. _Denis Engemann: http://denis-engemann.de
5011+
5012+
.. _Neeraj Gangwar: http://neerajgangwar.in

sklearn/model_selection/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from ._split import LeaveOneOut
88
from ._split import LeavePGroupsOut
99
from ._split import LeavePOut
10+
from ._split import RepeatedKFold
11+
from ._split import RepeatedStratifiedKFold
1012
from ._split import ShuffleSplit
1113
from ._split import GroupShuffleSplit
1214
from ._split import StratifiedShuffleSplit
@@ -36,6 +38,8 @@
3638
'LeaveOneOut',
3739
'LeavePGroupsOut',
3840
'LeavePOut',
41+
'RepeatedKFold',
42+
'RepeatedStratifiedKFold',
3943
'ParameterGrid',
4044
'ParameterSampler',
4145
'PredefinedSplit',

sklearn/model_selection/_split.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
'LeaveOneOut',
4242
'LeavePGroupsOut',
4343
'LeavePOut',
44+
'RepeatedStratifiedKFold',
45+
'RepeatedKFold',
4446
'ShuffleSplit',
4547
'GroupShuffleSplit',
4648
'StratifiedKFold',
@@ -397,6 +399,8 @@ class KFold(_BaseKFold):
397399
classification tasks).
398400
399401
GroupKFold: K-fold iterator variant with non-overlapping groups.
402+
403+
RepeatedKFold: Repeats K-Fold n times.
400404
"""
401405

402406
def __init__(self, n_splits=3, shuffle=False,
@@ -553,6 +557,9 @@ class StratifiedKFold(_BaseKFold):
553557
All the folds have size ``trunc(n_samples / n_splits)``, the last one has
554558
the complementary.
555559
560+
See also
561+
--------
562+
RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
556563
"""
557564

558565
def __init__(self, n_splits=3, shuffle=False, random_state=None):
@@ -913,6 +920,170 @@ def get_n_splits(self, X, y, groups):
913920
return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
914921

915922

923+
class _RepeatedSplits(with_metaclass(ABCMeta)):
924+
"""Repeated splits for an arbitrary randomized CV splitter.
925+
926+
Repeats splits for cross-validators n times with different randomization
927+
in each repetition.
928+
929+
Parameters
930+
----------
931+
cv : callable
932+
Cross-validator class.
933+
934+
n_repeats : int, default=10
935+
Number of times cross-validator needs to be repeated.
936+
937+
random_state : None, int or RandomState, default=None
938+
Random state to be used to generate random state for each
939+
repetition.
940+
941+
**cvargs : additional params
942+
Constructor parameters for cv. Must not contain random_state
943+
and shuffle.
944+
"""
945+
def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
946+
if not isinstance(n_repeats, (np.integer, numbers.Integral)):
947+
raise ValueError("Number of repetitions must be of Integral type.")
948+
949+
if n_repeats <= 1:
950+
raise ValueError("Number of repetitions must be greater than 1.")
951+
952+
if any(key in cvargs for key in ('random_state', 'shuffle')):
953+
raise ValueError(
954+
"cvargs must not contain random_state or shuffle.")
955+
956+
self.cv = cv
957+
self.n_repeats = n_repeats
958+
self.random_state = random_state
959+
self.cvargs = cvargs
960+
961+
def split(self, X, y=None, groups=None):
962+
"""Generates indices to split data into training and test set.
963+
964+
Parameters
965+
----------
966+
X : array-like, shape (n_samples, n_features)
967+
Training data, where n_samples is the number of samples
968+
and n_features is the number of features.
969+
970+
y : array-like, of length n_samples
971+
The target variable for supervised learning problems.
972+
973+
groups : array-like, with shape (n_samples,), optional
974+
Group labels for the samples used while splitting the dataset into
975+
train/test set.
976+
977+
Returns
978+
-------
979+
train : ndarray
980+
The training set indices for that split.
981+
982+
test : ndarray
983+
The testing set indices for that split.
984+
"""
985+
n_repeats = self.n_repeats
986+
rng = check_random_state(self.random_state)
987+
988+
for idx in range(n_repeats):
989+
cv = self.cv(random_state=rng, shuffle=True,
990+
**self.cvargs)
991+
for train_index, test_index in cv.split(X, y, groups):
992+
yield train_index, test_index
993+
994+
995+
class RepeatedKFold(_RepeatedSplits):
996+
"""Repeated K-Fold cross validator.
997+
998+
Repeats K-Fold n times with different randomization in each repetition.
999+
1000+
Read more in the :ref:`User Guide <cross_validation>`.
1001+
1002+
Parameters
1003+
----------
1004+
n_splits : int, default=5
1005+
Number of folds. Must be at least 2.
1006+
1007+
n_repeats : int, default=10
1008+
Number of times cross-validator needs to be repeated.
1009+
1010+
random_state : None, int or RandomState, default=None
1011+
Random state to be used to generate random state for each
1012+
repetition.
1013+
1014+
Examples
1015+
--------
1016+
>>> from sklearn.model_selection import RepeatedKFold
1017+
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
1018+
>>> y = np.array([0, 0, 1, 1])
1019+
>>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
1020+
>>> for train_index, test_index in rkf.split(X):
1021+
... print("TRAIN:", train_index, "TEST:", test_index)
1022+
... X_train, X_test = X[train_index], X[test_index]
1023+
... y_train, y_test = y[train_index], y[test_index]
1024+
...
1025+
TRAIN: [0 1] TEST: [2 3]
1026+
TRAIN: [2 3] TEST: [0 1]
1027+
TRAIN: [1 2] TEST: [0 3]
1028+
TRAIN: [0 3] TEST: [1 2]
1029+
1030+
1031+
See also
1032+
--------
1033+
RepeatedStratifiedKFold: Repeates Stratified K-Fold n times.
1034+
"""
1035+
def __init__(self, n_splits=5, n_repeats=10, random_state=None):
1036+
super(RepeatedKFold, self).__init__(
1037+
KFold, n_repeats, random_state, n_splits=n_splits)
1038+
1039+
1040+
class RepeatedStratifiedKFold(_RepeatedSplits):
1041+
"""Repeated Stratified K-Fold cross validator.
1042+
1043+
Repeats Stratified K-Fold n times with different randomization in each
1044+
repetition.
1045+
1046+
Read more in the :ref:`User Guide <cross_validation>`.
1047+
1048+
Parameters
1049+
----------
1050+
n_splits : int, default=5
1051+
Number of folds. Must be at least 2.
1052+
1053+
n_repeats : int, default=10
1054+
Number of times cross-validator needs to be repeated.
1055+
1056+
random_state : None, int or RandomState, default=None
1057+
Random state to be used to generate random state for each
1058+
repetition.
1059+
1060+
Examples
1061+
--------
1062+
>>> from sklearn.model_selection import RepeatedStratifiedKFold
1063+
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
1064+
>>> y = np.array([0, 0, 1, 1])
1065+
>>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
1066+
... random_state=36851234)
1067+
>>> for train_index, test_index in rskf.split(X, y):
1068+
... print("TRAIN:", train_index, "TEST:", test_index)
1069+
... X_train, X_test = X[train_index], X[test_index]
1070+
... y_train, y_test = y[train_index], y[test_index]
1071+
...
1072+
TRAIN: [1 2] TEST: [0 3]
1073+
TRAIN: [0 3] TEST: [1 2]
1074+
TRAIN: [1 3] TEST: [0 2]
1075+
TRAIN: [0 2] TEST: [1 3]
1076+
1077+
1078+
See also
1079+
--------
1080+
RepeatedKFold: Repeats K-Fold n times.
1081+
"""
1082+
def __init__(self, n_splits=5, n_repeats=10, random_state=None):
1083+
super(RepeatedStratifiedKFold, self).__init__(
1084+
StratifiedKFold, n_repeats, random_state, n_splits=n_splits)
1085+
1086+
9161087
class BaseShuffleSplit(with_metaclass(ABCMeta)):
9171088
"""Base class for ShuffleSplit and StratifiedShuffleSplit"""
9181089

sklearn/model_selection/tests/test_split.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
from sklearn.model_selection import check_cv
4343
from sklearn.model_selection import train_test_split
4444
from sklearn.model_selection import GridSearchCV
45+
from sklearn.model_selection import RepeatedKFold
46+
from sklearn.model_selection import RepeatedStratifiedKFold
4547

4648
from sklearn.linear_model import Ridge
4749

@@ -804,6 +806,76 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
804806
LeavePGroupsOut(n_groups=3).split(X, y, groups))
805807

806808

809+
def test_repeated_cv_value_errors():
810+
# n_repeats is not integer or <= 1
811+
for cv in (RepeatedKFold, RepeatedStratifiedKFold):
812+
assert_raises(ValueError, cv, n_repeats=1)
813+
assert_raises(ValueError, cv, n_repeats=1.5)
814+
815+
816+
def test_repeated_kfold_determinstic_split():
817+
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
818+
random_state = 258173307
819+
rkf = RepeatedKFold(
820+
n_splits=2,
821+
n_repeats=2,
822+
random_state=random_state)
823+
824+
# split should produce same and deterministic splits on
825+
# each call
826+
for _ in range(3):
827+
splits = rkf.split(X)
828+
train, test = next(splits)
829+
assert_array_equal(train, [2, 4])
830+
assert_array_equal(test, [0, 1, 3])
831+
832+
train, test = next(splits)
833+
assert_array_equal(train, [0, 1, 3])
834+
assert_array_equal(test, [2, 4])
835+
836+
train, test = next(splits)
837+
assert_array_equal(train, [0, 1])
838+
assert_array_equal(test, [2, 3, 4])
839+
840+
train, test = next(splits)
841+
assert_array_equal(train, [2, 3, 4])
842+
assert_array_equal(test, [0, 1])
843+
844+
assert_raises(StopIteration, next, splits)
845+
846+
847+
def test_repeated_stratified_kfold_determinstic_split():
848+
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
849+
y = [1, 1, 1, 0, 0]
850+
random_state = 1944695409
851+
rskf = RepeatedStratifiedKFold(
852+
n_splits=2,
853+
n_repeats=2,
854+
random_state=random_state)
855+
856+
# split should produce same and deterministic splits on
857+
# each call
858+
for _ in range(3):
859+
splits = rskf.split(X, y)
860+
train, test = next(splits)
861+
assert_array_equal(train, [1, 4])
862+
assert_array_equal(test, [0, 2, 3])
863+
864+
train, test = next(splits)
865+
assert_array_equal(train, [0, 2, 3])
866+
assert_array_equal(test, [1, 4])
867+
868+
train, test = next(splits)
869+
assert_array_equal(train, [2, 3])
870+
assert_array_equal(test, [0, 1, 4])
871+
872+
train, test = next(splits)
873+
assert_array_equal(train, [0, 1, 4])
874+
assert_array_equal(test, [2, 3])
875+
876+
assert_raises(StopIteration, next, splits)
877+
878+
807879
def test_train_test_split_errors():
808880
assert_raises(ValueError, train_test_split)
809881
assert_raises(ValueError, train_test_split, range(3), train_size=1.1)

0 commit comments

Comments
 (0)