diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index d8b3dd4dbe1d6..6001c3d9f92a0 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -162,16 +162,18 @@ Classes
    :template: class.rst

    cross_validation.KFold
+   cross_validation.LabelKFold
+   cross_validation.LabelShuffleSplit
    cross_validation.LeaveOneLabelOut
    cross_validation.LeaveOneOut
    cross_validation.LeavePLabelOut
    cross_validation.LeavePOut
    cross_validation.PredefinedSplit
-   cross_validation.StratifiedKFold
    cross_validation.ShuffleSplit
-   cross_validation.LabelShuffleSplit
+   cross_validation.StratifiedKFold
    cross_validation.StratifiedShuffleSplit
+
 .. autosummary::
    :toctree: generated/
    :template: function.rst
diff --git a/doc/modules/cross_validation.rst b/doc/modules/cross_validation.rst
index 53afdf53550b1..e76b95c6e48be 100644
--- a/doc/modules/cross_validation.rst
+++ b/doc/modules/cross_validation.rst
@@ -261,6 +261,33 @@ two slightly unbalanced classes::

   [0 1 2 4 5 6 7] [3 8 9]

+Label k-fold
+------------
+
+:class:`LabelKFold` is a variation of *k-fold* which ensures that the same
+label does not appear in both the testing and training sets. This is useful,
+for example, when the data come from different subjects and you want to avoid
+over-fitting (i.e., learning person-specific features) by testing and training
+on different subjects.
+
+Imagine you have three subjects, each with an associated number from 1 to 3::
+
+  >>> from sklearn.cross_validation import LabelKFold
+
+  >>> labels = [1, 1, 1, 2, 2, 2, 3, 3, 3, 3]
+
+  >>> lkf = LabelKFold(labels, n_folds=3)
+  >>> for train, test in lkf:
+  ...     print("%s %s" % (train, test))
+  [0 1 2 3 4 5] [6 7 8 9]
+  [0 1 2 6 7 8 9] [3 4 5]
+  [3 4 5 6 7 8 9] [0 1 2]
+
+Each subject is in a different testing fold, and the same subject is never in
+both testing and training. Notice that the folds do not have exactly the same
+size due to the imbalance in the data.
+
+
 Leave-One-Out - LOO
 -------------------

@@ -435,15 +462,15 @@ Label-Shuffle-Split

 :class:`LabelShuffleSplit`

-The :class:`LabelShuffleSplit` iterator behaves as a combination of
-:class:`ShuffleSplit` and :class:`LeavePLabelsOut`, and generates a
+The :class:`LabelShuffleSplit` iterator behaves as a combination of
+:class:`ShuffleSplit` and :class:`LeavePLabelOut`, and generates a
 sequence of randomized partitions in which a subset of labels are held out
 for each split.

 Here is a usage example::

   >>> from sklearn.cross_validation import LabelShuffleSplit
-
+
   >>> labels = [1, 1, 2, 2, 3, 3, 4, 4]
   >>> slo = LabelShuffleSplit(labels, n_iter=4, test_size=0.5,
   ...                         random_state=0)
diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 7ba8e03378b11..66df0155d76f4 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -34,9 +34,12 @@ New features
      function into a ``Pipeline``-compatible transformer object.
      By Joe Jevnik.

-   - :class:`cross_validation.LabelShuffleSplit` generates random train-test
-     splits, similar to :class:`cross_validation.ShuffleSplit`, except that
-     the splits are conditioned on a label array. By `Brian McFee`_.
+   - The new classes :class:`cross_validation.LabelKFold` and
+     :class:`cross_validation.LabelShuffleSplit` generate train-test splits,
+     similar to :class:`cross_validation.KFold` and
+     :class:`cross_validation.ShuffleSplit` respectively, except that the
+     splits are conditioned on a label array. By `Brian McFee`_, Jean
+     Kossaifi and `Gilles Louppe`_.

 Enhancements

@@ -127,11 +130,11 @@ Enhancements

    - Allow :func:`datasets.make_multilabel_classification` to output
      a sparse ``y``.
      By Kashif Rasul.
-
+
    - :class:`cluster.DBSCAN` now accepts a sparse matrix of precomputed
      distances, allowing memory-efficient distance precomputation. By
      `Joel Nothman`_.
-
+
    - :class:`tree.DecisionTreeClassifier` now exposes an ``apply`` method
      for retrieving the leaf indices samples are predicted as. By
      `Daniel Galvez`_ and `Gilles Louppe`_.
diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index 4c49bee86517d..117327dbc814a 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -33,6 +33,7 @@
 from .utils.fixes import bincount

 __all__ = ['KFold',
+           'LabelKFold',
            'LeaveOneLabelOut',
            'LeaveOneOut',
            'LeavePLabelOut',
@@ -273,15 +274,15 @@ class KFold(_BaseKFold):
         Whether to shuffle the data before splitting into batches.

     random_state : None, int or RandomState
-        When shuffle=True, pseudo-random number generator state used for
+        When shuffle=True, pseudo-random number generator state used for
         shuffling. If None, use default numpy RNG for shuffling.

     Examples
     --------
-    >>> from sklearn import cross_validation
+    >>> from sklearn.cross_validation import KFold
     >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
     >>> y = np.array([1, 2, 3, 4])
-    >>> kf = cross_validation.KFold(4, n_folds=2)
+    >>> kf = KFold(4, n_folds=2)
     >>> len(kf)
     2
     >>> print(kf)  # doctest: +NORMALIZE_WHITESPACE
@@ -304,6 +305,8 @@ class KFold(_BaseKFold):
     StratifiedKFold: take label information into account to avoid building
     folds with imbalanced class distributions (for binary or multiclass
     classification tasks).
+
+    LabelKFold: K-fold iterator variant with non-overlapping labels.
     """

     def __init__(self, n, n_folds=3, shuffle=False,
@@ -339,6 +342,108 @@ def __len__(self):
         return self.n_folds


+class LabelKFold(_BaseKFold):
+    """K-fold iterator variant with non-overlapping labels.
+
+    The same label will not appear in two different folds (the number of
+    distinct labels has to be at least equal to the number of folds).
+
+    The folds are approximately balanced in the sense that the number of
+    distinct labels is approximately the same in each fold.
+
+    Parameters
+    ----------
+    labels : array-like with shape (n_samples, )
+        Contains a label for each sample.
+        The folds are built so that the same label does not appear in two
+        different folds.
+
+    n_folds : int, default=3
+        Number of folds. Must be at least 2.
+
+    Examples
+    --------
+    >>> from sklearn.cross_validation import LabelKFold
+    >>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
+    >>> y = np.array([1, 2, 3, 4])
+    >>> labels = np.array([0, 0, 2, 2])
+    >>> label_kfold = LabelKFold(labels, n_folds=2)
+    >>> len(label_kfold)
+    2
+    >>> print(label_kfold)
+    sklearn.cross_validation.LabelKFold(n_labels=2, n_folds=2)
+    >>> for train_index, test_index in label_kfold:
+    ...     print("TRAIN:", train_index, "TEST:", test_index)
+    ...     X_train, X_test = X[train_index], X[test_index]
+    ...     y_train, y_test = y[train_index], y[test_index]
+    ...     print(X_train, X_test, y_train, y_test)
+    ...
+    TRAIN: [0 1] TEST: [2 3]
+    [[1 2]
+     [3 4]] [[5 6]
+     [7 8]] [1 2] [3 4]
+    TRAIN: [2 3] TEST: [0 1]
+    [[5 6]
+     [7 8]] [[1 2]
+     [3 4]] [3 4] [1 2]
+
+    See also
+    --------
+    LeaveOneLabelOut for splitting the data according to explicit,
+    domain-specific stratification of the dataset.
+ """ + def __init__(self, labels, n_folds=3, shuffle=False, random_state=None): + super(LabelKFold, self).__init__(len(labels), n_folds, shuffle, + random_state) + + unique_labels, labels = np.unique(labels, return_inverse=True) + n_labels = len(unique_labels) + + if n_folds > n_labels: + raise ValueError( + ("Cannot have number of folds n_folds={0} greater" + " than the number of labels: {1}.").format(n_folds, + n_labels)) + + # Weight labels by their number of occurences + n_samples_per_label = np.bincount(labels) + + # Distribute the most frequent labels first + indices = np.argsort(n_samples_per_label)[::-1] + n_samples_per_label = n_samples_per_label[indices] + + # Total weight of each fold + n_samples_per_fold = np.zeros(n_folds) + + # Mapping from label index to fold index + label_to_fold = np.zeros(len(unique_labels)) + + # Distribute samples by adding the largest weight to the lightest fold + for label_index, weight in enumerate(n_samples_per_label): + lightest_fold = np.argmin(n_samples_per_fold) + n_samples_per_fold[lightest_fold] += weight + label_to_fold[indices[label_index]] = lightest_fold + + self.idxs = label_to_fold[labels] + + if shuffle: + rng = check_random_state(self.random_state) + rng.shuffle(self.idxs) + + def _iter_test_indices(self): + for i in range(self.n_folds): + yield (self.idxs == i) + + def __repr__(self): + return '{0}.{1}(n_labels={2}, n_folds={3})'.format( + self.__class__.__module__, + self.__class__.__name__, + self.n, + self.n_folds, + ) + + def __len__(self): + return self.n_folds + + class StratifiedKFold(_BaseKFold): """Stratified K-Folds cross validation iterator @@ -363,15 +477,15 @@ class StratifiedKFold(_BaseKFold): into batches. random_state : None, int or RandomState - When shuffle=True, pseudo-random number generator state used for + When shuffle=True, pseudo-random number generator state used for shuffling. If None, use default numpy RNG for shuffling. Examples -------- - >>> from sklearn import cross_validation + >>> from sklearn.cross_validation import StratifiedKFold >>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) >>> y = np.array([0, 0, 1, 1]) - >>> skf = cross_validation.StratifiedKFold(y, n_folds=2) + >>> skf = StratifiedKFold(y, n_folds=2) >>> len(skf) 2 >>> print(skf) # doctest: +NORMALIZE_WHITESPACE @@ -389,6 +503,9 @@ class StratifiedKFold(_BaseKFold): All the folds have size trunc(n_samples / n_folds), the last one has the complementary. + See also + -------- + LabelKFold: K-fold iterator variant with non-overlapping labels. """ def __init__(self, y, n_folds=3, shuffle=False, @@ -497,6 +614,9 @@ class LeaveOneLabelOut(_PartitionIterator): [3 4]] [[5 6] [7 8]] [1 2] [1 2] + See also + -------- + LabelKFold: K-fold iterator variant with non-overlapping labels. """ def __init__(self, labels): @@ -572,6 +692,10 @@ class LeavePLabelOut(_PartitionIterator): TRAIN: [0] TEST: [1 2] [[1 2]] [[3 4] [5 6]] [1] [2 1] + + See also + -------- + LabelKFold: K-fold iterator variant with non-overlapping labels. """ def __init__(self, labels, p): @@ -1079,11 +1203,11 @@ def cross_val_predict(estimator, X, y=None, cv=None, n_jobs=1, supervised learning. cv : integer or cross-validation generator, optional, default=3 - A cross-validation generator to use. If int, determines the number - of folds in StratifiedKFold if estimator is a classifier and the - target y is binary or multiclass, or the number of folds in KFold + A cross-validation generator to use. 
+        of folds in StratifiedKFold if estimator is a classifier and the
+        target y is binary or multiclass, or the number of folds in KFold
         otherwise.
-        Specific cross-validation objects can be passed, see
+        Specific cross-validation objects can be passed, see
         sklearn.cross_validation module for the list of possible objects.
         This generator must include all elements in the test set exactly
         once. Otherwise, a ValueError is raised.
@@ -1248,11 +1363,11 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
         ``scorer(estimator, X, y)``.

     cv : integer or cross-validation generator, optional, default=3
-        A cross-validation generator to use. If int, determines the number
-        of folds in StratifiedKFold if estimator is a classifier and the
-        target y is binary or multiclass, or the number of folds in KFold
+        A cross-validation generator to use. If int, determines the number
+        of folds in StratifiedKFold if estimator is a classifier and the
+        target y is binary or multiclass, or the number of folds in KFold
         otherwise.
-        Specific cross-validation objects can be passed, see
+        Specific cross-validation objects can be passed, see
         sklearn.cross_validation module for the list of possible objects.

     n_jobs : integer, optional
@@ -1569,11 +1684,11 @@ def permutation_test_score(estimator, X, y, cv=None,
         ``scorer(estimator, X, y)``.

     cv : integer or cross-validation generator, optional, default=3
-        A cross-validation generator to use. If int, determines the number
-        of folds in StratifiedKFold if estimator is a classifier and the
-        target y is binary or multiclass, or the number of folds in KFold
+        A cross-validation generator to use. If int, determines the number
+        of folds in StratifiedKFold if estimator is a classifier and the
+        target y is binary or multiclass, or the number of folds in KFold
         otherwise.
-        Specific cross-validation objects can be passed, see
+        Specific cross-validation objects can be passed, see
         sklearn.cross_validation module for the list of possible objects.
     n_permutations : integer, optional
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index 3c45261103411..df2d90877bc50 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -13,6 +13,7 @@
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import assert_greater
+from sklearn.utils.testing import assert_greater_equal
 from sklearn.utils.testing import assert_less
 from sklearn.utils.testing import assert_not_equal
 from sklearn.utils.testing import assert_array_almost_equal
@@ -358,6 +359,71 @@ def test_kfold_can_detect_dependent_samples_on_digits():  # see #2372
     assert_greater(mean_score, 0.85)


+def test_label_kfold():
+    rng = np.random.RandomState(0)
+
+    # Parameters of the test
+    n_labels = 15
+    n_samples = 1000
+    n_folds = 5
+
+    # Construct the test data
+    tolerance = 0.05 * n_samples  # 5 percent error allowed
+    labels = rng.randint(0, n_labels, n_samples)
+    folds = cval.LabelKFold(labels, n_folds).idxs
+    ideal_n_samples_per_fold = n_samples // n_folds
+
+    # Check that folds have approximately the same size
+    assert_equal(len(folds), len(labels))
+    for i in np.unique(folds):
+        assert_greater_equal(tolerance,
+                             abs(sum(folds == i) - ideal_n_samples_per_fold))
+
+    # Check that each label appears only in 1 fold
+    for label in np.unique(labels):
+        assert_equal(len(np.unique(folds[labels == label])), 1)
+
+    # Check that no label is on both sides of the split
+    labels = np.asarray(labels, dtype=object)
+    for train, test in cval.LabelKFold(labels, n_folds=n_folds):
+        assert_equal(len(np.intersect1d(labels[train], labels[test])), 0)
+
+    # Construct the test data
+    labels = ['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean',
+              'Francis', 'Robert', 'Michel', 'Rachel', 'Lois',
+              'Michelle', 'Bernard', 'Marion', 'Laura', 'Jean',
+              'Rachel', 'Franck', 'John', 'Gael', 'Anna', 'Alix',
+              'Robert', 'Marion', 'David', 'Tony', 'Abel', 'Becky',
+              'Madmood', 'Cary', 'Mary', 'Alexandre', 'David', 'Francis',
+              'Barack', 'Abdoul', 'Rasha', 'Xi', 'Silvia']
+
+    n_labels = len(np.unique(labels))
+    n_samples = len(labels)
+    n_folds = 5
+    tolerance = 0.05 * n_samples  # 5 percent error allowed
+    folds = cval.LabelKFold(labels, n_folds).idxs
+    ideal_n_samples_per_fold = n_samples // n_folds
+
+    # Check that folds have approximately the same size
+    assert_equal(len(folds), len(labels))
+    for i in np.unique(folds):
+        assert_greater_equal(tolerance,
+                             abs(sum(folds == i) - ideal_n_samples_per_fold))
+
+    # Check that each label appears only in 1 fold
+    for label in np.unique(labels):
+        assert_equal(len(np.unique(folds[labels == label])), 1)
+
+    # Check that no label is on both sides of the split
+    labels = np.asarray(labels, dtype=object)
+    for train, test in cval.LabelKFold(labels, n_folds=n_folds):
+        assert_equal(len(np.intersect1d(labels[train], labels[test])), 0)
+
+    # Should fail if there are more folds than labels
+    labels = np.array([1, 1, 1, 2, 2])
+    assert_raises(ValueError, cval.LabelKFold, labels, n_folds=3)
+
+
 def test_shuffle_split():
     ss1 = cval.ShuffleSplit(10, test_size=0.2, random_state=0)
     ss2 = cval.ShuffleSplit(10, test_size=2, random_state=0)
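
A quick end-to-end check of the patch is to pass the new iterator anywhere a
CV generator is accepted, such as ``cross_val_score``. The snippet below is an
illustrative sketch, not part of the patch: the estimator, data shapes, seed
and label layout are arbitrary choices made for the example::

  import numpy as np
  from sklearn.cross_validation import LabelKFold, cross_val_score
  from sklearn.linear_model import LogisticRegression

  # Illustrative data: 40 samples drawn from 8 subjects (5 samples each).
  rng = np.random.RandomState(0)
  X = rng.randn(40, 3)
  y = rng.randint(0, 2, 40)
  labels = np.repeat(np.arange(8), 5)

  # Every sample of a given subject lands on one side of each split, so
  # the reported scores are not inflated by subject-specific features.
  scores = cross_val_score(LogisticRegression(), X, y,
                           cv=LabelKFold(labels, n_folds=4))
  print(scores)

Each of the four scores is computed on held-out subjects only, which is the
point of conditioning the folds on the label array.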