-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG+2] Label K-Fold #5190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[MRG+2] Label K-Fold #5190
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
from .utils.fixes import bincount | ||
|
||
__all__ = ['KFold', | ||
'LabelKFold', | ||
'LeaveOneLabelOut', | ||
'LeaveOneOut', | ||
'LeavePLabelOut', | ||
|
@@ -273,15 +274,15 @@ class KFold(_BaseKFold): | |
Whether to shuffle the data before splitting into batches. | ||
|
||
random_state : None, int or RandomState | ||
When shuffle=True, pseudo-random number generator state used for | ||
When shuffle=True, pseudo-random number generator state used for | ||
shuffling. If None, use default numpy RNG for shuffling. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn import cross_validation | ||
>>> from sklearn.cross_validation import KFold | ||
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) | ||
>>> y = np.array([1, 2, 3, 4]) | ||
>>> kf = cross_validation.KFold(4, n_folds=2) | ||
>>> kf = KFold(4, n_folds=2) | ||
>>> len(kf) | ||
2 | ||
>>> print(kf) # doctest: +NORMALIZE_WHITESPACE | ||
|
@@ -304,6 +305,8 @@ class KFold(_BaseKFold): | |
StratifiedKFold: take label information into account to avoid building | ||
folds with imbalanced class distributions (for binary or multiclass | ||
classification tasks). | ||
|
||
LabelKFold: K-fold iterator variant with non-overlapping labels. | ||
""" | ||
|
||
def __init__(self, n, n_folds=3, shuffle=False, | ||
|
@@ -339,6 +342,117 @@ def __len__(self): | |
return self.n_folds | ||
|
||
|
||
class LabelKFold(_BaseKFold): | ||
"""K-fold iterator variant with non-overlapping labels. | ||
|
||
The same label will not appear in two different folds (the number of | ||
distinct labels has to be at least equal to the number of folds). | ||
|
||
The folds are approximately balanced in the sense that the number of | ||
distinct labels is approximately the same in each fold. | ||
|
||
Parameters | ||
---------- | ||
labels : array-like with shape (n_samples, ) | ||
Contains a label for each sample. | ||
The folds are built so that the same label does not appear in two | ||
different folds. | ||
|
||
n_folds : int, default=3 | ||
Number of folds. Must be at least 2. | ||
|
||
shuffle : boolean, optional | ||
Whether to shuffle the data before splitting into batches. | ||
|
||
random_state : None, int or RandomState | ||
When shuffle=True, pseudo-random number generator state used for | ||
shuffling. If None, use default numpy RNG for shuffling. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn.cross_validation import LabelKFold | ||
>>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) | ||
>>> y = np.array([1, 2, 3, 4]) | ||
>>> labels = np.array([0, 0, 2, 2]) | ||
>>> label_kfold = LabelKFold(labels, n_folds=2) | ||
>>> len(label_kfold) | ||
2 | ||
>>> print(label_kfold) | ||
sklearn.cross_validation.LabelKFold(n_labels=4, n_folds=2) | ||
>>> for train_index, test_index in label_kfold: | ||
... print("TRAIN:", train_index, "TEST:", test_index) | ||
... X_train, X_test = X[train_index], X[test_index] | ||
... y_train, y_test = y[train_index], y[test_index] | ||
... print(X_train, X_test, y_train, y_test) | ||
... | ||
TRAIN: [0 1] TEST: [2 3] | ||
[[1 2] | ||
[3 4]] [[5 6] | ||
[7 8]] [1 2] [3 4] | ||
TRAIN: [2 3] TEST: [0 1] | ||
[[5 6] | ||
[7 8]] [[1 2] | ||
[3 4]] [3 4] [1 2] | ||
|
||
See also | ||
-------- | ||
LeaveOneLabelOut for splitting the data according to explicit, | ||
domain-specific stratification of the dataset. | ||
""" | ||
def __init__(self, labels, n_folds=3, shuffle=False, random_state=None): | ||
super(LabelKFold, self).__init__(len(labels), n_folds, shuffle, | ||
random_state) | ||
|
||
unique_labels, labels = np.unique(labels, return_inverse=True) | ||
n_labels = len(unique_labels) | ||
|
||
if n_folds > n_labels: | ||
raise ValueError( | ||
("Cannot have number of folds n_folds={0} greater" | ||
" than the number of labels: {1}.").format(n_folds, | ||
n_labels)) | ||
|
||
# Weight labels by their number of occurences | ||
n_samples_per_label = np.bincount(labels) | ||
|
||
# Distribute the most frequent labels first | ||
indices = np.argsort(n_samples_per_label)[::-1] | ||
n_samples_per_label = n_samples_per_label[indices] | ||
|
||
# Total weight of each fold | ||
n_samples_per_fold = np.zeros(n_folds) | ||
|
||
# Mapping from label index to fold index | ||
label_to_fold = np.zeros(len(unique_labels)) | ||
|
||
# Distribute samples by adding the largest weight to the lightest fold | ||
for label_index, weight in enumerate(n_samples_per_label): | ||
lightest_fold = np.argmin(n_samples_per_fold) | ||
n_samples_per_fold[lightest_fold] += weight | ||
label_to_fold[indices[label_index]] = lightest_fold | ||
|
||
self.idxs = label_to_fold[labels] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. idxs -> indexs or something more coherent with the rest of the module? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, this comes from the |
||
|
||
if shuffle: | ||
rng = check_random_state(self.random_state) | ||
rng.shuffle(self.idxs) | ||
|
||
def _iter_test_indices(self): | ||
for i in range(self.n_folds): | ||
yield (self.idxs == i) | ||
|
||
def __repr__(self): | ||
return '{0}.{1}(n_labels={2}, n_folds={3})'.format( | ||
self.__class__.__module__, | ||
self.__class__.__name__, | ||
self.n, | ||
self.n_folds, | ||
) | ||
|
||
def __len__(self): | ||
return self.n_folds | ||
|
||
|
||
class StratifiedKFold(_BaseKFold): | ||
"""Stratified K-Folds cross validation iterator | ||
|
||
|
@@ -363,15 +477,15 @@ class StratifiedKFold(_BaseKFold): | |
into batches. | ||
|
||
random_state : None, int or RandomState | ||
When shuffle=True, pseudo-random number generator state used for | ||
When shuffle=True, pseudo-random number generator state used for | ||
shuffling. If None, use default numpy RNG for shuffling. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn import cross_validation | ||
>>> from sklearn.cross_validation import StratifiedKFold | ||
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) | ||
>>> y = np.array([0, 0, 1, 1]) | ||
>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) | ||
>>> skf = StratifiedKFold(y, n_folds=2) | ||
>>> len(skf) | ||
2 | ||
>>> print(skf) # doctest: +NORMALIZE_WHITESPACE | ||
|
@@ -389,6 +503,9 @@ class StratifiedKFold(_BaseKFold): | |
All the folds have size trunc(n_samples / n_folds), the last one has the | ||
complementary. | ||
|
||
See also | ||
-------- | ||
LabelKFold: K-fold iterator variant with non-overlapping labels. | ||
""" | ||
|
||
def __init__(self, y, n_folds=3, shuffle=False, | ||
|
@@ -497,6 +614,9 @@ class LeaveOneLabelOut(_PartitionIterator): | |
[3 4]] [[5 6] | ||
[7 8]] [1 2] [1 2] | ||
|
||
See also | ||
-------- | ||
LabelKFold: K-fold iterator variant with non-overlapping labels. | ||
""" | ||
|
||
def __init__(self, labels): | ||
|
@@ -572,6 +692,10 @@ class LeavePLabelOut(_PartitionIterator): | |
TRAIN: [0] TEST: [1 2] | ||
[[1 2]] [[3 4] | ||
[5 6]] [1] [2 1] | ||
|
||
See also | ||
-------- | ||
LabelKFold: K-fold iterator variant with non-overlapping labels. | ||
""" | ||
|
||
def __init__(self, labels, p): | ||
|
@@ -1079,11 +1203,11 @@ def cross_val_predict(estimator, X, y=None, cv=None, n_jobs=1, | |
supervised learning. | ||
|
||
cv : integer or cross-validation generator, optional, default=3 | ||
A cross-validation generator to use. If int, determines the number | ||
of folds in StratifiedKFold if estimator is a classifier and the | ||
target y is binary or multiclass, or the number of folds in KFold | ||
A cross-validation generator to use. If int, determines the number | ||
of folds in StratifiedKFold if estimator is a classifier and the | ||
target y is binary or multiclass, or the number of folds in KFold | ||
otherwise. | ||
Specific cross-validation objects can be passed, see | ||
Specific cross-validation objects can be passed, see | ||
sklearn.cross_validation module for the list of possible objects. | ||
This generator must include all elements in the test set exactly once. | ||
Otherwise, a ValueError is raised. | ||
|
@@ -1248,11 +1372,11 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1, | |
``scorer(estimator, X, y)``. | ||
|
||
cv : integer or cross-validation generator, optional, default=3 | ||
A cross-validation generator to use. If int, determines the number | ||
of folds in StratifiedKFold if estimator is a classifier and the | ||
target y is binary or multiclass, or the number of folds in KFold | ||
A cross-validation generator to use. If int, determines the number | ||
of folds in StratifiedKFold if estimator is a classifier and the | ||
target y is binary or multiclass, or the number of folds in KFold | ||
otherwise. | ||
Specific cross-validation objects can be passed, see | ||
Specific cross-validation objects can be passed, see | ||
sklearn.cross_validation module for the list of possible objects. | ||
|
||
n_jobs : integer, optional | ||
|
@@ -1569,11 +1693,11 @@ def permutation_test_score(estimator, X, y, cv=None, | |
``scorer(estimator, X, y)``. | ||
|
||
cv : integer or cross-validation generator, optional, default=3 | ||
A cross-validation generator to use. If int, determines the number | ||
of folds in StratifiedKFold if estimator is a classifier and the | ||
target y is binary or multiclass, or the number of folds in KFold | ||
A cross-validation generator to use. If int, determines the number | ||
of folds in StratifiedKFold if estimator is a classifier and the | ||
target y is binary or multiclass, or the number of folds in KFold | ||
otherwise. | ||
Specific cross-validation objects can be passed, see | ||
Specific cross-validation objects can be passed, see | ||
sklearn.cross_validation module for the list of possible objects. | ||
|
||
n_permutations : integer, optional | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean in decending order?