[MRG] add stratify and shuffle variants for GroupKFold #9413

Closed
wants to merge 6 commits into from
37 changes: 32 additions & 5 deletions doc/modules/cross_validation.rst
@@ -532,12 +532,12 @@ parameter.
Group k-fold
------------

class:GroupKFold is a variation of k-fold which ensures that the same group is
not represented in both testing and training sets. For example if the data is
obtained from different subjects with several samples per-subject and if the
:class:`GroupKFold` is a variation of k-fold which ensures that the same group
is not represented in both testing and training sets. For example, if the data
is obtained from different subjects with several samples per subject and if the
model is flexible enough to learn from highly person-specific features, it
could fail to generalize to new subjects. class:GroupKFold makes it possible
to detect this kind of overfitting situations.
could fail to generalize to new subjects. :class:`GroupKFold` makes it possible
to avoid this kind of overfitting situation.

Imagine you have three subjects, each with an associated number from 1 to 3::

@@ -558,6 +558,33 @@ Each subject is in a different testing fold, and the same subject is never in
both testing and training. Notice that the folds do not have exactly the same
size due to the imbalance in the data.

The same group will not appear in two different folds;
this is a hard constraint. After this constraint is enforced,
there are still multiple ways to divide groups across folds. A greedy strategy
is used to create folds of approximately the same size: at each step, the next
group is assigned to the fold that currently holds the fewest items. The order
in which groups are assigned can be used to tweak the distribution of the
resulting folds.
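
The greedy step described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in this pull request: the helper name ``assign_groups_to_folds`` and its arguments are hypothetical, and ties between equally light folds are broken by taking the lowest fold index.

```python
def assign_groups_to_folds(group_sizes, n_folds, order):
    """Greedily assign groups to folds (hypothetical helper, for illustration).

    group_sizes : list with the number of samples in each group.
    order       : group indices, in the order in which they are assigned.
    Returns the fold index of each group and the final fold sizes.
    """
    fold_sizes = [0] * n_folds
    group_to_fold = [None] * len(group_sizes)
    for group in order:
        # Pick the fold holding the fewest samples so far
        # (ties go to the lowest fold index).
        lightest = fold_sizes.index(min(fold_sizes))
        fold_sizes[lightest] += group_sizes[group]
        group_to_fold[group] = lightest
    return group_to_fold, fold_sizes


group_sizes = [5, 1, 4, 2, 3]
# Largest groups first, which is the ordering described for the default.
largest_first = sorted(range(len(group_sizes)), key=lambda g: -group_sizes[g])
group_to_fold, fold_sizes = assign_groups_to_folds(group_sizes, 2, largest_first)
print(group_to_fold)  # [0, 0, 1, 0, 1]
print(fold_sizes)     # [8, 7]
```

Each group lands in exactly one fold, so the hard constraint holds for any ordering; only the balance of the fold sizes depends on the order.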

The default, ``method='balance'``, tries to balance the sizes of the folds
by assigning the largest groups first. With ``method='stratify_median'`` or
``method='stratify_mode'``, groups are spread across the folds by stratifying on
the ``y`` variable, as far as possible. Median should be used for continuous
variables, and mode for discrete variables. Stratification may be important
when the ``y`` variable has a skewed distribution; stratification can help
ensure that rare ``y`` values are represented in each fold.
Finally, ``method='shuffle'`` adds randomness by shuffling the groups. This
strategy is useful when you want to generate multiple sets of folds; repeated
use of the other methods would deterministically result in the same folds.

The stratification and shuffle options work best when groups are relatively
small (i.e., there are many groups), to avoid folds of uneven sizes.
Stratification relies on summarizing each group by a single ``y`` value (its
median or mode) that is representative of the group.
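
To make the difference between the methods concrete, the sketch below computes only the order in which groups would be handed out under each option. It is a standard-library illustration under assumptions: ``group_order`` and its ``seed`` argument are hypothetical names, the median is taken as the middle element of the sorted values (the upper median for even counts), and ties in the mode are resolved in favour of the smallest value.

```python
import random
from collections import Counter, defaultdict


def group_order(groups, y, method, seed=0):
    """Return group labels in assignment order (hypothetical helper)."""
    y_by_group = defaultdict(list)  # a fresh list per group label
    for group, value in zip(groups, y):
        y_by_group[group].append(value)
    labels = sorted(y_by_group)

    def median(values):
        # Middle element of the sorted values (upper median for even counts).
        return sorted(values)[len(values) // 2]

    def mode(values):
        counts = Counter(values)
        # Most frequent value; ties go to the smallest value.
        return min(counts, key=lambda v: (-counts[v], v))

    if method == "balance":
        # Largest groups first.
        return sorted(labels, key=lambda g: -len(y_by_group[g]))
    if method == "stratify_median":
        return sorted(labels, key=lambda g: median(y_by_group[g]))
    if method == "stratify_mode":
        return sorted(labels, key=lambda g: mode(y_by_group[g]))
    if method == "shuffle":
        shuffled = list(labels)
        random.Random(seed).shuffle(shuffled)
        return shuffled
    raise ValueError("unknown method: %r" % method)


groups = ["a", "a", "a", "b", "b", "c"]
y_continuous = [3.0, 1.0, 2.0, 0.5, 0.7, 9.0]
y_discrete = [1, 1, 0, 0, 0, 1]

print(group_order(groups, y_continuous, "balance"))          # ['a', 'b', 'c']
print(group_order(groups, y_continuous, "stratify_median"))  # ['b', 'a', 'c']
print(group_order(groups, y_discrete, "stratify_mode"))      # ['b', 'a', 'c']
```

Because consecutive groups in a stratified order have similar ``y`` summaries and are handed to the lightest fold one at a time, neighbouring groups tend to end up in different folds.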

.. topic:: Examples

* :ref:`sphx_glr_auto_examples_model_selection_plot_groupkfold.py`


Leave One Group Out
-------------------
33 changes: 33 additions & 0 deletions examples/model_selection/plot_groupkfold.py
@@ -0,0 +1,33 @@
"""
====================
Group K-Fold methods
====================

This example demonstrates when the stratify option of GroupKFold has an
advantage.
"""
from matplotlib import pyplot as plt
import numpy as np
from sklearn.model_selection import GroupKFold

print(__doc__)

rng = np.random.RandomState(0)
n_samples = 1000
n_groups = 100
n_folds = 2
X = np.arange(n_samples)
# Sort data points to highlight the effect of stratification
y = np.sort(rng.normal(size=n_samples))
groups = np.sort(rng.randint(0, n_groups, n_samples))

fig, axes = plt.subplots(1, 3, figsize=(18, 4), sharex=True, sharey=True)
for n, method in enumerate(('balance', 'stratify_median', 'shuffle')):
    cv = GroupKFold(n_folds, method=method)
    for m, (train, test) in enumerate(cv.split(X, y, groups)):
        axes[n].hist(y[test], bins=20, histtype='step',
                     label='fold %d' % (m + 1))
        print('%s fold %d: %d items' % (method, m + 1, len(test)))
    axes[n].set_xlabel(method)
    axes[n].legend(loc='upper right')
plt.show()
61 changes: 50 additions & 11 deletions sklearn/model_selection/_split.py
@@ -440,25 +440,31 @@ class GroupKFold(_BaseKFold):
The same group will not appear in two different folds (the number of
distinct groups has to be at least equal to the number of folds).

The folds are approximately balanced in the sense that the number of
distinct groups is approximately the same in each fold.

Parameters
----------
n_splits : int, default=3
Number of folds. Must be at least 2.
method : str, default='balance'
One of 'balance', 'stratify_median', 'stratify_mode', 'shuffle'.
By default, try to equalize the sizes of the resulting folds.
If 'stratify_median', distribute groups evenly across folds according
to their median ``y`` values; use when ``y`` is continuous.
If 'stratify_mode', distribute groups evenly across folds according to
the mode of their ``y`` values; use when ``y`` is discrete.
If 'shuffle', shuffle the groups to randomize their assignments to
folds.

Examples
--------
>>> from sklearn.model_selection import GroupKFold
>>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> y = np.array([1, 2, 3, 4])
>>> groups = np.array([0, 0, 2, 2])
>>> group_kfold = GroupKFold(n_splits=2)
>>> group_kfold = GroupKFold(method='balance', n_splits=2)
>>> group_kfold.get_n_splits(X, y, groups)
2
>>> print(group_kfold)
GroupKFold(n_splits=2)
GroupKFold(method='balance', n_splits=2)
>>> for train_index, test_index in group_kfold.split(X, y, groups):
... print("TRAIN:", train_index, "TEST:", test_index)
... X_train, X_test = X[train_index], X[test_index]
@@ -480,7 +486,13 @@ class GroupKFold(_BaseKFold):
For splitting the data according to explicit domain-specific
stratification of the dataset.
"""
    def __init__(self, n_splits=3):
    def __init__(self, n_splits=3, method='balance'):
        if method not in ('balance', 'stratify_median', 'stratify_mode',
                          'shuffle'):
            raise ValueError("The 'method' parameter should be one of: "
                             "'balance', 'stratify_median', 'stratify_mode', "
                             "'shuffle'")
        self.method = method
        super(GroupKFold, self).__init__(n_splits, shuffle=False,
                                         random_state=None)

@@ -489,7 +501,8 @@ def _iter_test_indices(self, X, y, groups):
            raise ValueError("The 'groups' parameter should not be None.")
        groups = check_array(groups, ensure_2d=False, dtype=None)

        unique_groups, groups = np.unique(groups, return_inverse=True)
        unique_groups, unique_indices, groups = np.unique(
            groups, return_index=True, return_inverse=True)
        n_groups = len(unique_groups)

        if self.n_splits > n_groups:
@@ -500,17 +513,43 @@ def _iter_test_indices(self, X, y, groups):
        # Weight groups by their number of occurrences
        n_samples_per_group = np.bincount(groups)

        # Distribute the most frequent groups first
        indices = np.argsort(n_samples_per_group)[::-1]
        if self.method == 'balance':
            # Distribute the most frequent groups first
            indices = np.argsort(n_samples_per_group)[::-1]
        elif self.method.startswith('stratify_'):
            # Distribute according to the median or mode of y per group
            if y is None:
                raise ValueError("The 'y' parameter should not be None.")
            y = check_array(y, ensure_2d=False, dtype=None)
            # 'groups' now holds integer indices into unique_groups, so
            # collect the y values by index, with a separate list per group
            # (dict.fromkeys(keys, []) would share one list across all keys)
            y_by_group = [[] for _ in range(n_groups)]
            for group_index, y_value in zip(groups, y):
                y_by_group[group_index].append(y_value)
            if self.method == 'stratify_median':
                # manual median: middle element of the sorted values
                by_group = [sorted(values)[len(values) // 2]
                            for values in y_by_group]
            elif self.method == 'stratify_mode':
                def mode(x):
                    values, counts = np.unique(x, return_counts=True)
                    return values[np.argmax(counts)]

                by_group = [mode(values) for values in y_by_group]
            indices = np.argsort(by_group)
        elif self.method == 'shuffle':
            # Shuffle the groups
            rng = check_random_state(self.random_state)
            indices = np.arange(n_groups)
            rng.shuffle(indices)
        n_samples_per_group = n_samples_per_group[indices]

        # Total weight of each fold
        n_samples_per_fold = np.zeros(self.n_splits)

        # Mapping from group index to fold index
        group_to_fold = np.zeros(len(unique_groups))
        group_to_fold = np.zeros(n_groups)

        # Distribute samples by adding the largest weight to the lightest fold
        # Distribute samples by adding groups to the lightest fold
        for group_index, weight in enumerate(n_samples_per_group):
            lightest_fold = np.argmin(n_samples_per_fold)
            n_samples_per_fold[lightest_fold] += weight
Expand Down
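
A note on collecting ``y`` values per group: building the container with ``dict.fromkeys(keys, [])`` is a known Python pitfall, because every key then refers to the same list object. The short standard-library sketch below (variable names are illustrative only) shows the effect and one way around it.

```python
# dict.fromkeys reuses ONE list object as the value for every key.
shared = dict.fromkeys(["a", "b"], [])
shared["a"].append(1)
print(shared)  # {'a': [1], 'b': [1]}

# A dict comprehension evaluates [] once per key, giving independent lists.
separate = {key: [] for key in ["a", "b"]}
separate["a"].append(1)
print(separate)  # {'a': [1], 'b': []}
```

With the shared list, every group would see the ``y`` values of all groups, so per-group medians or modes would come out identical. Note also that after ``np.unique(..., return_inverse=True)`` the ``groups`` array contains integer positions rather than the original labels, so a container keyed by the original labels cannot be indexed with it.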
43 changes: 23 additions & 20 deletions sklearn/model_selection/tests/test_split.py
@@ -1220,28 +1220,31 @@ def test_group_kfold():

    ideal_n_groups_per_fold = n_samples // n_splits

    len(np.unique(groups))
    # Get the test fold indices from the test set indices of each fold
    folds = np.zeros(n_samples)
    lkf = GroupKFold(n_splits=n_splits)
    for i, (_, test) in enumerate(lkf.split(X, y, groups)):
        folds[test] = i

    # Check that folds have approximately the same size
    assert_equal(len(folds), len(groups))
    for i in np.unique(folds):
        assert_greater_equal(tolerance,
                             abs(sum(folds == i) - ideal_n_groups_per_fold))

    # Check that each group appears only in 1 fold
    for group in np.unique(groups):
        assert_equal(len(np.unique(folds[groups == group])), 1)
    for method in ('balance', 'stratify_median', 'stratify_mode', 'shuffle'):
        # Get the test fold indices from the test set indices of each fold
        folds = np.zeros(n_samples)
        lkf = GroupKFold(n_splits=n_splits, method=method)
        for i, (_, test) in enumerate(lkf.split(X, y, groups)):
            folds[test] = i

        # Check that folds have approximately the same size
        if method == 'balance':
            assert_equal(len(folds), len(groups))
            for i in np.unique(folds):
                assert_greater_equal(
                    tolerance,
                    abs(sum(folds == i) - ideal_n_groups_per_fold))

        # Check that each group appears only in 1 fold
        for group in np.unique(groups):
            assert_equal(len(np.unique(folds[groups == group])), 1)

    # Check that no group is on both sides of the split
    groups = np.asarray(groups, dtype=object)
    for train, test in lkf.split(X, y, groups):
        assert_equal(len(np.intersect1d(groups[train], groups[test])), 0)
        # Check that no group is on both sides of the split
        groups = np.asarray(groups, dtype=object)
        for train, test in lkf.split(X, y, groups):
            assert_equal(len(np.intersect1d(groups[train], groups[test])), 0)

    lkf = GroupKFold(n_splits=n_splits, method='balance')
# Construct the test data
groups = np.array(['Albert', 'Jean', 'Bertrand', 'Michel', 'Jean',
'Francis', 'Robert', 'Michel', 'Rachel', 'Lois',