|
41 | 41 | 'LeaveOneOut',
|
42 | 42 | 'LeavePGroupsOut',
|
43 | 43 | 'LeavePOut',
|
| 44 | + 'RepeatedStratifiedKFold', |
| 45 | + 'RepeatedKFold', |
44 | 46 | 'ShuffleSplit',
|
45 | 47 | 'GroupShuffleSplit',
|
46 | 48 | 'StratifiedKFold',
|
@@ -397,6 +399,8 @@ class KFold(_BaseKFold):
|
397 | 399 | classification tasks).
|
398 | 400 |
|
399 | 401 | GroupKFold: K-fold iterator variant with non-overlapping groups.
|
| 402 | +
|
| 403 | + RepeatedKFold: Repeats K-Fold n times. |
400 | 404 | """
|
401 | 405 |
|
402 | 406 | def __init__(self, n_splits=3, shuffle=False,
|
@@ -553,6 +557,9 @@ class StratifiedKFold(_BaseKFold):
|
553 | 557 | All the folds have size ``trunc(n_samples / n_splits)``, the last one has
|
554 | 558 | the complementary.
|
555 | 559 |
|
| 560 | + See also |
| 561 | + -------- |
| 562 | + RepeatedStratifiedKFold: Repeats Stratified K-Fold n times. |
556 | 563 | """
|
557 | 564 |
|
558 | 565 | def __init__(self, n_splits=3, shuffle=False, random_state=None):
|
@@ -913,6 +920,170 @@ def get_n_splits(self, X, y, groups):
|
913 | 920 | return int(comb(len(np.unique(groups)), self.n_groups, exact=True))
|
914 | 921 |
|
915 | 922 |
|
class _RepeatedSplits(with_metaclass(ABCMeta)):
    """Repeated splits for an arbitrary randomized CV splitter.

    Repeats splits for cross-validators n times with different randomization
    in each repetition.

    Parameters
    ----------
    cv : callable
        Cross-validator class. It is instantiated once per repetition as
        ``cv(random_state=rng, shuffle=True, **cvargs)``, so it must accept
        ``random_state`` and ``shuffle`` keyword arguments.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated. Must be at
        least 1.

    random_state : None, int or RandomState, default=None
        Random state to be used to generate random state for each
        repetition.

    **cvargs : additional params
        Constructor parameters for cv. Must not contain random_state
        and shuffle.

    Raises
    ------
    ValueError
        If ``n_repeats`` is not an integer, if it is smaller than 1, or if
        ``cvargs`` contains ``random_state`` or ``shuffle``.
    """
    def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
        if not isinstance(n_repeats, (np.integer, numbers.Integral)):
            raise ValueError("Number of repetitions must be of Integral type.")

        # A single repetition is a valid (if degenerate) request; only
        # zero or negative counts are rejected.
        if n_repeats <= 0:
            raise ValueError("Number of repetitions must be greater than 0.")

        # random_state and shuffle are supplied by this class on every
        # repetition, so the caller must not pass them through cvargs.
        if any(key in cvargs for key in ('random_state', 'shuffle')):
            raise ValueError(
                "cvargs must not contain random_state or shuffle.")

        self.cv = cv
        self.n_repeats = n_repeats
        self.random_state = random_state
        self.cvargs = cvargs

    def split(self, X, y=None, groups=None):
        """Generates indices to split data into training and test set.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data, where n_samples is the number of samples
            and n_features is the number of features.

        y : array-like, of length n_samples
            The target variable for supervised learning problems.

        groups : array-like, with shape (n_samples,), optional
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        # One random state object is created per call and handed to every
        # repetition's splitter, rather than re-deriving it from
        # ``self.random_state`` each time.
        rng = check_random_state(self.random_state)

        for _ in range(self.n_repeats):
            cv = self.cv(random_state=rng, shuffle=True,
                         **self.cvargs)
            for train_index, test_index in cv.split(X, y, groups):
                yield train_index, test_index

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : object, optional
            Forwarded unchanged to the wrapped cross-validator's
            ``get_n_splits``.

        y : object, optional
            Forwarded unchanged to the wrapped cross-validator's
            ``get_n_splits``.

        groups : array-like, with shape (n_samples,), optional
            Forwarded unchanged to the wrapped cross-validator's
            ``get_n_splits``.

        Returns
        -------
        n_splits : int
            Number of splits of one repetition of the wrapped
            cross-validator multiplied by ``n_repeats``.
        """
        rng = check_random_state(self.random_state)
        # Build the splitter exactly as ``split`` does so that the count is
        # reported by the same kind of object that produces the splits.
        cv = self.cv(random_state=rng, shuffle=True,
                     **self.cvargs)
        return cv.get_n_splits(X, y, groups) * self.n_repeats
| 993 | + |
| 994 | + |
class RepeatedKFold(_RepeatedSplits):
    """Repeated K-Fold cross validator.

    Runs K-Fold ``n_repeats`` times in a row, using a different
    randomization for each repetition.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : None, int or RandomState, default=None
        Random state to be used to generate random state for each
        repetition.

    Examples
    --------
    >>> from sklearn.model_selection import RepeatedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
    >>> for train_index, test_index in rkf.split(X):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...
    TRAIN: [0 1] TEST: [2 3]
    TRAIN: [2 3] TEST: [0 1]
    TRAIN: [1 2] TEST: [0 3]
    TRAIN: [0 3] TEST: [1 2]

    See also
    --------
    RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
    """
    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
        # KFold is the splitter class that gets repeated; n_splits is the
        # only constructor argument forwarded to it.
        super(RepeatedKFold, self).__init__(
            cv=KFold,
            n_repeats=n_repeats,
            random_state=random_state,
            n_splits=n_splits)
| 1038 | + |
| 1039 | + |
class RepeatedStratifiedKFold(_RepeatedSplits):
    """Repeated Stratified K-Fold cross validator.

    Runs Stratified K-Fold ``n_repeats`` times in a row, using a different
    randomization for each repetition.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of folds. Must be at least 2.

    n_repeats : int, default=10
        Number of times cross-validator needs to be repeated.

    random_state : None, int or RandomState, default=None
        Random state to be used to generate random state for each
        repetition.

    Examples
    --------
    >>> from sklearn.model_selection import RepeatedStratifiedKFold
    >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
    >>> y = np.array([0, 0, 1, 1])
    >>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
    ...     random_state=36851234)
    >>> for train_index, test_index in rskf.split(X, y):
    ...     print("TRAIN:", train_index, "TEST:", test_index)
    ...     X_train, X_test = X[train_index], X[test_index]
    ...     y_train, y_test = y[train_index], y[test_index]
    ...
    TRAIN: [1 2] TEST: [0 3]
    TRAIN: [0 3] TEST: [1 2]
    TRAIN: [1 3] TEST: [0 2]
    TRAIN: [0 2] TEST: [1 3]

    See also
    --------
    RepeatedKFold: Repeats K-Fold n times.
    """
    def __init__(self, n_splits=5, n_repeats=10, random_state=None):
        # StratifiedKFold is the splitter class that gets repeated;
        # n_splits is the only constructor argument forwarded to it.
        super(RepeatedStratifiedKFold, self).__init__(
            cv=StratifiedKFold,
            n_repeats=n_repeats,
            random_state=random_state,
            n_splits=n_splits)
| 1085 | + |
| 1086 | + |
916 | 1087 | class BaseShuffleSplit(with_metaclass(ABCMeta)):
|
917 | 1088 | """Base class for ShuffleSplit and StratifiedShuffleSplit"""
|
918 | 1089 |
|
|
0 commit comments