Skip to content

Commit 37ecdd7

Browse files
committed
Changed SubjectIndependentKFold to DisjointGroupKFold
cosmetic changes test (fix seed correctly, use assert_equal for meaningful error messages)
1 parent 53063ba commit 37ecdd7

File tree

2 files changed

+49
-38
lines changed

2 files changed

+49
-38
lines changed

sklearn/cross_validation.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -332,12 +332,12 @@ def __len__(self):
332332
return self.n_folds
333333

334334

335-
def subject_independent_folds(subjects, n_folds=3):
336-
""" Creates folds where a same subject is not in two different folds
335+
def disjoint_group_folds(groups, n_folds=3):
336+
"""Creates folds where a same group is not in two different folds
337337
338338
Parameters
339339
----------
340-
subjects: iterable of shape (n_samples, )
340+
groups: iterable of shape (n_samples, )
341341
contains an id for each sample
342342
The folds are built so that the same id doesn't appear in two different folds
343343
@@ -352,34 +352,46 @@ def subject_independent_folds(subjects, n_folds=3):
352352
353353
Notes
354354
-----
355-
The folds are built by distributing the subjects by frequency of appearance.
355+
The folds are built by distributing the groups by frequency of appearance.
356356
"""
357-
subjects = np.array(subjects)
358-
unique_subjects = np.unique(subjects)
357+
groups = np.array(groups)
358+
unique_groups = np.unique(groups)
359359

360-
# number of occurrence of each subject (its "weight")
361-
weight_per_subject = sorted([(sum(subjects == i), i) for i in unique_subjects])
360+
# number of occurrence of each group (its "weight")
361+
weight_per_group = sorted([(sum(groups == group_id), group_id) for group_id in unique_groups])
362362
# Total weight of each fold
363363
weight_per_fold = np.zeros(n_folds)
364364
# For each sample, a digit between 0 and (n_folds - 1) to tell which fold it belongs to
365-
folds = np.zeros(len(subjects))
365+
folds = np.zeros(len(groups))
366366

367367
# While there are weights, distribute them
368368
# Specifically, add the biggest weight to the lightest fold
369-
while weight_per_subject:
369+
while weight_per_group:
370370
ind_min = np.argmin(weight_per_fold)
371-
w, actor = weight_per_subject.pop()
371+
w, group_id = weight_per_group.pop()
372372
weight_per_fold[ind_min] += w
373-
folds[subjects == actor] = ind_min
373+
folds[groups == group_id] = ind_min
374374

375375
return folds
376376

377377

378-
class SubjectIndependentKfold(_BaseKFold):
379-
def __init__(self, subjects, n_folds=3):
378+
class DisjointGroupKfold(_BaseKFold):
379+
def __init__(self, groups, n_folds=3):
380+
"""Creates K approximately equilibrated folds
381+
where the same group will not appear in two different folds
382+
383+
Parameters
384+
----------
385+
groups: numpy array of shape (n_samples, )
386+
contains an id for each sample
387+
The folds are built so that the same id doesn't appear in two different folds
388+
389+
n_folds: int, default is 3
390+
number of folds
391+
"""
380392
self.n_folds = n_folds
381-
self.n = len(subjects)
382-
self.idxs = subject_independent_folds(subjects=subjects, n_folds=n_folds)
393+
self.n = len(groups)
394+
self.idxs = disjoint_group_folds(groups=groups, n_folds=n_folds)
383395

384396
def _iter_test_indices(self):
385397
for i in range(self.n_folds):

sklearn/tests/test_cross_validation.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.utils.testing import assert_almost_equal
1313
from sklearn.utils.testing import assert_raises
1414
from sklearn.utils.testing import assert_greater
15+
from sklearn.utils.testing import assert_greater_equal
1516
from sklearn.utils.testing import assert_less
1617
from sklearn.utils.testing import assert_not_equal
1718
from sklearn.utils.testing import assert_array_almost_equal
@@ -1042,55 +1043,53 @@ def test_check_is_partition():
10421043

10431044
def test_subject_independent_folds():
10441045
""" Check that the function produces equilibrated folds
1045-
with no subject appearing in two different folds
1046+
with no group appearing in two different folds
10461047
"""
10471048
# Fix the seed for reproducibility
1048-
np.random.seed(0)
1049+
rng = np.random.RandomState(0)
10491050

10501051
# Parameters of the test
1051-
n_subjects = 15
1052+
n_groups = 15
10521053
n_samples = 1000
10531054
n_folds = 5
10541055

10551056
# Construct the test data
10561057
tolerance = 0.05 * n_samples # 5 percent error allowed
1057-
subjects = np.random.randint(0, n_subjects, n_samples)
1058-
folds = cval.subject_independent_folds(subjects, n_folds)
1059-
ideal_n_subjects_per_fold = n_samples // n_folds
1058+
groups = np.random.randint(0, n_groups, n_samples)
1059+
folds = cval.disjoint_group_folds(groups, n_folds)
1060+
ideal_n_groups_per_fold = n_samples // n_folds
10601061

10611062
# Check that folds have approximately the same size
1062-
assert(len(folds)==len(subjects))
1063+
assert_equal(len(folds), len(groups))
10631064
for i in np.unique(folds):
1064-
assert(abs(sum(folds == i) - ideal_n_subjects_per_fold) <= tolerance)
1065+
assert_greater_equal(tolerance, abs(sum(folds == i) - ideal_n_groups_per_fold))
10651066

10661067
# Check that each subjects appears only in 1 fold
1067-
for subject in np.unique(subjects):
1068-
assert(len(np.unique(folds[subjects == subject])) == 1)
1068+
for group in np.unique(groups):
1069+
assert_equal(len(np.unique(folds[groups == group])), 1)
10691070

1070-
subjects = ['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean',
1071+
# Construct the test data
1072+
groups = ['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean',
10711073
'Francis', 'Robert', 'Michel', 'Rachel', 'Lois',
10721074
'Michelle', 'Bernard', 'Marion', 'Laura', 'Jean',
10731075
'Rachel', 'Franck', 'John', 'Gael', 'Anna', 'Alix',
10741076
'Robert', 'Marion', 'David', 'Tony', 'Abel', 'Becky',
10751077
'Madmood', 'Cary', 'Mary', 'Alexandre', 'David', 'Francis',
10761078
'Barack', 'Abdoul', 'Rasha', 'Xi', 'Silvia']
10771079

1078-
n_subjects = len(np.unique(subjects))
1079-
n_samples = len(subjects)
1080+
n_groups = len(np.unique(groups))
1081+
n_samples = len(groups)
10801082
n_folds = 5
1081-
1082-
# Construct the test data
10831083
tolerance = 0.05 * n_samples # 5 percent error allowed
1084-
subjects = np.random.randint(0, n_subjects, n_samples)
1085-
folds = cval.subject_independent_folds(subjects, n_folds)
1086-
ideal_n_subjects_per_fold = n_samples // n_folds
1084+
folds = cval.disjoint_group_folds(groups, n_folds)
1085+
ideal_n_groups_per_fold = n_samples // n_folds
10871086

10881087
# Check that folds have approximately the same size
1089-
assert(len(folds)==len(subjects))
1088+
assert_equal(len(folds), len(groups))
10901089
for i in np.unique(folds):
1091-
assert(abs(sum(folds == i) - ideal_n_subjects_per_fold) <= tolerance)
1090+
assert_greater_equal(tolerance, abs(sum(folds == i) - ideal_n_groups_per_fold))
10921091

10931092
# Check that each subjects appears only in 1 fold
1094-
for subject in np.unique(subjects):
1095-
assert(len(np.unique(folds[subjects == subject])) == 1)
1093+
for group in np.unique(groups):
1094+
assert_equal(len(np.unique(folds[groups == group])), 1)
10961095

0 commit comments

Comments
 (0)