-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+1] Repeated K-Fold and Repeated Stratified K-Fold #8120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3742d69
42f711f
ad2753c
e62bc27
c631b31
5e3078d
373e5a5
cba817a
218e409
f9f48a8
feb9172
eb38ebe
338997e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,8 @@ | |
'LeaveOneOut', | ||
'LeavePGroupsOut', | ||
'LeavePOut', | ||
'RepeatedStratifiedKFold', | ||
'RepeatedKFold', | ||
'ShuffleSplit', | ||
'GroupShuffleSplit', | ||
'StratifiedKFold', | ||
|
@@ -397,6 +399,8 @@ class KFold(_BaseKFold): | |
classification tasks). | ||
|
||
GroupKFold: K-fold iterator variant with non-overlapping groups. | ||
|
||
RepeatedKFold: Repeats K-Fold n times. | ||
""" | ||
|
||
def __init__(self, n_splits=3, shuffle=False, | ||
|
@@ -553,6 +557,9 @@ class StratifiedKFold(_BaseKFold): | |
All the folds have size ``trunc(n_samples / n_splits)``, the last one has | ||
the complementary. | ||
|
||
See also | ||
-------- | ||
RepeatedStratifiedKFold: Repeats Stratified K-Fold n times. | ||
""" | ||
|
||
def __init__(self, n_splits=3, shuffle=False, random_state=None): | ||
|
@@ -913,6 +920,170 @@ def get_n_splits(self, X, y, groups): | |
return int(comb(len(np.unique(groups)), self.n_groups, exact=True)) | ||
|
||
|
||
class _RepeatedSplits(with_metaclass(ABCMeta)): | ||
"""Repeated splits for an arbitrary randomized CV splitter. | ||
|
||
Repeats splits for cross-validators n times with different randomization | ||
in each repetition. | ||
|
||
Parameters | ||
---------- | ||
cv : callable | ||
Cross-validator class. | ||
|
||
n_repeats : int, default=10 | ||
Number of times cross-validator needs to be repeated. | ||
|
||
random_state : None, int or RandomState, default=None | ||
Random state to be used to generate random state for each | ||
repetition. | ||
|
||
**cvargs : additional params | ||
Constructor parameters for cv. Must not contain random_state | ||
and shuffle. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not an obligation as _RepeatedSplits is private but can you raise an error in split to check that ? |
||
""" | ||
def __init__(self, cv, n_repeats=10, random_state=None, **cvargs): | ||
if not isinstance(n_repeats, (np.integer, numbers.Integral)): | ||
raise ValueError("Number of repetitions must be of Integral type.") | ||
|
||
if n_repeats <= 1: | ||
raise ValueError("Number of repetitions must be greater than 1.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Never check values in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't error be thrown at the construction time if there is some discrepancy with the parameters passed? In There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In sklearn for the estimator, we never check the error in @jnothman As I'm not 100% sure can you confirm that ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now, at least, CV splitters are a bit special in this regard. Checking in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok thx @jnothman |
||
|
||
if any(key in cvargs for key in ('random_state', 'shuffle')): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keeping the same as both are of equal length. :P |
||
raise ValueError( | ||
"cvargs must not contain random_state or shuffle.") | ||
|
||
self.cv = cv | ||
self.n_repeats = n_repeats | ||
self.random_state = random_state | ||
self.cvargs = cvargs | ||
|
||
def split(self, X, y=None, groups=None): | ||
"""Generates indices to split data into training and test set. | ||
|
||
Parameters | ||
---------- | ||
X : array-like, shape (n_samples, n_features) | ||
Training data, where n_samples is the number of samples | ||
and n_features is the number of features. | ||
|
||
y : array-like, of length n_samples | ||
The target variable for supervised learning problems. | ||
|
||
groups : array-like, with shape (n_samples,), optional | ||
Group labels for the samples used while splitting the dataset into | ||
train/test set. | ||
|
||
Returns | ||
------- | ||
train : ndarray | ||
The training set indices for that split. | ||
|
||
test : ndarray | ||
The testing set indices for that split. | ||
""" | ||
n_repeats = self.n_repeats | ||
rng = check_random_state(self.random_state) | ||
|
||
for idx in range(n_repeats): | ||
cv = self.cv(random_state=rng, shuffle=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we maybe want to raise nice errors if these arguments are not present? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't get you. Which arguments? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Honestly I have no idea what I meant.... ? |
||
**self.cvargs) | ||
for train_index, test_index in cv.split(X, y, groups): | ||
yield train_index, test_index | ||
|
||
|
||
class RepeatedKFold(_RepeatedSplits): | ||
"""Repeated K-Fold cross validator. | ||
|
||
Repeats K-Fold n times with different randomization in each repetition. | ||
|
||
Read more in the :ref:`User Guide <cross_validation>`. | ||
|
||
Parameters | ||
---------- | ||
n_splits : int, default=5 | ||
Number of folds. Must be at least 2. | ||
|
||
n_repeats : int, default=10 | ||
Number of times cross-validator needs to be repeated. | ||
|
||
random_state : None, int or RandomState, default=None | ||
Random state to be used to generate random state for each | ||
repetition. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn.model_selection import RepeatedKFold | ||
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) | ||
>>> y = np.array([0, 0, 1, 1]) | ||
>>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124) | ||
>>> for train_index, test_index in rkf.split(X): | ||
... print("TRAIN:", train_index, "TEST:", test_index) | ||
... X_train, X_test = X[train_index], X[test_index] | ||
... y_train, y_test = y[train_index], y[test_index] | ||
... | ||
TRAIN: [0 1] TEST: [2 3] | ||
TRAIN: [2 3] TEST: [0 1] | ||
TRAIN: [1 2] TEST: [0 3] | ||
TRAIN: [0 3] TEST: [1 2] | ||
|
||
|
||
See also | ||
-------- | ||
RepeatedStratifiedKFold: Repeates Stratified K-Fold n times. | ||
""" | ||
def __init__(self, n_splits=5, n_repeats=10, random_state=None): | ||
super(RepeatedKFold, self).__init__( | ||
KFold, n_repeats, random_state, n_splits=n_splits) | ||
|
||
|
||
class RepeatedStratifiedKFold(_RepeatedSplits): | ||
"""Repeated Stratified K-Fold cross validator. | ||
|
||
Repeats Stratified K-Fold n times with different randomization in each | ||
repetition. | ||
|
||
Read more in the :ref:`User Guide <cross_validation>`. | ||
|
||
Parameters | ||
---------- | ||
n_splits : int, default=5 | ||
Number of folds. Must be at least 2. | ||
|
||
n_repeats : int, default=10 | ||
Number of times cross-validator needs to be repeated. | ||
|
||
random_state : None, int or RandomState, default=None | ||
Random state to be used to generate random state for each | ||
repetition. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn.model_selection import RepeatedStratifiedKFold | ||
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) | ||
>>> y = np.array([0, 0, 1, 1]) | ||
>>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2, | ||
... random_state=36851234) | ||
>>> for train_index, test_index in rskf.split(X, y): | ||
... print("TRAIN:", train_index, "TEST:", test_index) | ||
... X_train, X_test = X[train_index], X[test_index] | ||
... y_train, y_test = y[train_index], y[test_index] | ||
... | ||
TRAIN: [1 2] TEST: [0 3] | ||
TRAIN: [0 3] TEST: [1 2] | ||
TRAIN: [1 3] TEST: [0 2] | ||
TRAIN: [0 2] TEST: [1 3] | ||
|
||
|
||
See also | ||
-------- | ||
RepeatedKFold: Repeats K-Fold n times. | ||
""" | ||
def __init__(self, n_splits=5, n_repeats=10, random_state=None): | ||
super(RepeatedStratifiedKFold, self).__init__( | ||
StratifiedKFold, n_repeats, random_state, n_splits=n_splits) | ||
|
||
|
||
class BaseShuffleSplit(with_metaclass(ABCMeta)): | ||
"""Base class for ShuffleSplit and StratifiedShuffleSplit""" | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would say "it can be used to run KFold multiple times to increase the fidelity of the estimate? Or can we say to decrease the variance? Is that accurate?