diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 4309c8a21dfc8..6147086d2bb4e 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -168,6 +168,7 @@ Classes cross_validation.LeavePOut cross_validation.PredefinedSplit cross_validation.StratifiedKFold + cross_validation.DisjointLabelKFold cross_validation.ShuffleSplit cross_validation.StratifiedShuffleSplit diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst index 0aa6bf1e3b692..5d6dd2d041a6f 100644 --- a/doc/modules/cross_validation.rst +++ b/doc/modules/cross_validation.rst @@ -261,6 +261,33 @@ two slightly unbalanced classes:: [0 1 2 4 5 6 7] [3 8 9] +Disjoint label KFold +-------------------- + +:class:`DisjointLabelKFold` is a variation of *k-fold* which ensures that the same +label is not in both testing and training sets. +This is necessary, for example, if you obtained data from different subjects and you +want to avoid over-fitting (i.e. learning person-specific features) by testing and +training on different subjects. + +Imagine you have three subjects, each with an associated number from 1 to 3:: + + >>> from sklearn.cross_validation import DisjointLabelKFold + + >>> labels = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3] + + >>> dlkf = DisjointLabelKFold(labels, 3) + >>> for train, test in dlkf: + ... print("%s %s" % (train, test)) + [0 1 2 3 4 5] [6 7 8 9] + [0 1 2 6 7 8 9] [3 4 5] + [3 4 5 6 7 8 9] [0 1 2] + +Each subject is in a different testing fold, and the same subject is never in both +testing and training. +Notice that the folds do not have exactly the same size due to the imbalance in the data. 
+ + Leave-One-Out - LOO ------------------- diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index dacf4c4c67f63..ed7458be80bb1 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -297,6 +297,8 @@ class KFold(_BaseKFold): StratifiedKFold: take label information into account to avoid building folds with imbalanced class distributions (for binary or multiclass classification tasks). + + DisjointLabelKFold: K-fold iterator variant with non-overlapping labels. """ def __init__(self, n, n_folds=3, shuffle=False, @@ -332,6 +334,133 @@ def __len__(self): return self.n_folds +def disjoint_label_folds(labels, n_folds=3): + """Creates folds such that the same label never appears in two different folds. + + Parameters + ---------- + labels: numpy array, shape (n_samples,) + Contains an id for each sample. + The folds are built so that the same id doesn't appear in two different folds. + + n_folds: int, default=3 + Number of folds to split the data into. + + Returns + ------- + folds: numpy array of shape (n_samples, ) + Array of integers between 0 and (n_folds - 1). + folds[i] is the index of the fold to which sample i is assigned. + + Notes + ----- + The folds are built by distributing the labels by frequency of appearance. + The number of labels has to be at least equal to the number of folds. 
+ """ + labels = np.array(labels) + unique_labels, labels = np.unique(labels, return_inverse=True) + n_labels = len(unique_labels) + if n_folds > n_labels: + raise ValueError( + ("Cannot have number of folds n_folds={0} greater" + " than the number of labels: {1}.").format(n_folds, n_labels)) + + # number of occurrences of each label (its "weight") + samples_per_label = np.bincount(labels) + # We want to distribute the most frequent labels first + ind = np.argsort(samples_per_label, kind='mergesort')[::-1] + samples_per_label = samples_per_label[ind] + + # Total weight of each fold + samples_per_fold = np.zeros(n_folds, dtype=np.uint64) + + # Mapping from label index to fold index + label_to_fold = np.zeros(len(unique_labels), dtype=np.uintp) + + # While there are weights, distribute them + # Specifically, add the biggest weight to the lightest fold + for label_index, w in enumerate(samples_per_label): + ind_min = np.argmin(samples_per_fold) + samples_per_fold[ind_min] += w + label_to_fold[ind[label_index]] = ind_min + + folds = label_to_fold[labels] + + return folds + + +class DisjointLabelKFold(_BaseKFold): + """K-fold iterator variant with non-overlapping labels. + + The same label will not appear in two different folds (the number of + labels has to be at least equal to the number of folds). + + The folds are approximately balanced in the sense that the number of + distinct labels is approximately the same in each fold. + + Parameters + ---------- + labels : array-like with shape (n_samples, ) + Contains a label for each sample. + The folds are built so that the same label doesn't appear in two different folds. + + n_folds : int, default is 3 + Number of folds. 
+ + Examples + -------- + >>> from sklearn import cross_validation + >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + >>> y = np.array([1, 2, 3, 4]) + >>> labels = np.array([0, 0, 2, 2]) + >>> dl_kfold = cross_validation.DisjointLabelKFold(labels, n_folds=2) + >>> len(dl_kfold) + 2 + >>> print(dl_kfold) + sklearn.cross_validation.DisjointLabelKFold(n_labels=4, n_folds=2) + >>> for train_index, test_index in dl_kfold: + ... print("TRAIN:", train_index, "TEST:", test_index) + ... X_train, X_test = X[train_index], X[test_index] + ... y_train, y_test = y[train_index], y[test_index] + ... print(X_train, X_test, y_train, y_test) + ... + TRAIN: [0 1] TEST: [2 3] + [[1 2] + [3 4]] [[5 6] + [7 8]] [1 2] [3 4] + TRAIN: [2 3] TEST: [0 1] + [[5 6] + [7 8]] [[1 2] + [3 4]] [3 4] [1 2] + + See also + -------- + LeaveOneLabelOut for splitting the data according to explicit, + domain-specific stratification of the dataset. + """ + def __init__(self, labels, n_folds=3): + # No shuffling implemented yet + super(DisjointLabelKFold, self).__init__(len(labels), n_folds, False, None) + self.n_folds = n_folds + self.n = len(labels) + self.idxs = disjoint_label_folds(labels=labels, n_folds=n_folds) + + def _iter_test_indices(self): + for i in range(self.n_folds): + yield (self.idxs == i) + + def __repr__(self): + return '{0}.{1}(n_labels={2}, n_folds={3})'.format( + self.__class__.__module__, + self.__class__.__name__, + self.n, + self.n_folds, + ) + + def __len__(self): + return self.n_folds + + class StratifiedKFold(_BaseKFold): """Stratified K-Folds cross validation iterator @@ -380,6 +509,9 @@ class StratifiedKFold(_BaseKFold): All the folds have size trunc(n_samples / n_folds), the last one has the complementary. + See also + -------- + DisjointLabelKFold: K-fold iterator variant with non-overlapping labels. 
""" def __init__(self, y, n_folds=3, shuffle=False, @@ -486,6 +618,9 @@ class LeaveOneLabelOut(_PartitionIterator): [3 4]] [[5 6] [7 8]] [1 2] [1 2] + See also + -------- + DisjointLabelKFold: K-fold iterator variant with non-overlapping labels. """ def __init__(self, labels): @@ -559,6 +694,10 @@ class LeavePLabelOut(_PartitionIterator): TRAIN: [0] TEST: [1 2] [[1 2]] [[3 4] [5 6]] [1] [2 1] + + See also + -------- + DisjointLabelKFold: K-fold iterator variant with non-overlapping labels. """ def __init__(self, labels, p): diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index 786bf561ec5e2..b7e4e15ef76eb 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -12,6 +12,7 @@ from sklearn.utils.testing import assert_almost_equal from sklearn.utils.testing import assert_raises from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import assert_greater_equal from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_not_equal from sklearn.utils.testing import assert_array_almost_equal @@ -1038,3 +1039,70 @@ def test_check_is_partition(): p[0] = 23 assert_false(cval._check_is_partition(p, 100)) + + +def test_disjoint_label_folds(): + ## Check that the function produces equilibrated folds + ## with no label appearing in two different folds + + # Fix the seed for reproducibility + rng = np.random.RandomState(0) + + # Parameters of the test + n_labels = 15 + n_samples = 1000 + n_folds = 5 + + # Construct the test data + tolerance = 0.05 * n_samples # 5 percent error allowed + labels = rng.randint(0, n_labels, n_samples) + folds = cval.disjoint_label_folds(labels, n_folds) + ideal_n_labels_per_fold = n_samples // n_folds + + # Check that folds have approximately the same size + assert_equal(len(folds), len(labels)) + for i in np.unique(folds): + assert_greater_equal(tolerance, abs(sum(folds == i) - ideal_n_labels_per_fold)) + + # Check 
that each label appears only in 1 fold + for label in np.unique(labels): + assert_equal(len(np.unique(folds[labels == label])), 1) + + # Check that no label is on both sides of the split + labels = np.asarray(labels, dtype=object) # to allow fancy indexing on labels + for train, test in cval.DisjointLabelKFold(labels, n_folds=n_folds): + assert_equal(len(np.intersect1d(labels[train], labels[test])), 0) + + # Construct the test data + labels = ['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean', + 'Francis', 'Robert', 'Michel', 'Rachel', 'Lois', + 'Michelle', 'Bernard', 'Marion', 'Laura', 'Jean', + 'Rachel', 'Franck', 'John', 'Gael', 'Anna', 'Alix', + 'Robert', 'Marion', 'David', 'Tony', 'Abel', 'Becky', + 'Madmood', 'Cary', 'Mary', 'Alexandre', 'David', 'Francis', + 'Barack', 'Abdoul', 'Rasha', 'Xi', 'Silvia'] + + n_labels = len(np.unique(labels)) + n_samples = len(labels) + n_folds = 5 + tolerance = 0.05 * n_samples # 5 percent error allowed + folds = cval.disjoint_label_folds(labels, n_folds) + ideal_n_labels_per_fold = n_samples // n_folds + + # Check that folds have approximately the same size + assert_equal(len(folds), len(labels)) + for i in np.unique(folds): + assert_greater_equal(tolerance, abs(sum(folds == i) - ideal_n_labels_per_fold)) + + # Check that each label appears only in 1 fold + for label in np.unique(labels): + assert_equal(len(np.unique(folds[labels == label])), 1) + + # Check that no label is on both sides of the split + labels = np.asarray(labels, dtype=object) # to allow fancy indexing on labels + for train, test in cval.DisjointLabelKFold(labels, n_folds=n_folds): + assert_equal(len(np.intersect1d(labels[train], labels[test])), 0) + + # Should fail if there are more folds than labels + labels = np.array([1, 1, 1, 2, 2]) + assert_raises(ValueError, cval.disjoint_label_folds, labels, n_folds=3)