Commit afc6cc5

qinhanmin2014 authored and GaelVaroquaux committed
FIX Shuffle each class's samples with different random_state in StratifiedKFold (#13124)
* Enable StratifiedKFold to produce different splits
* what's new
* redundant statement
* update what's new
* redundant comment
* add a test
* move what's new entry
* review comment
* review comment
1 parent 66899ed commit afc6cc5

3 files changed, +18 -3 lines changed

doc/whats_new/v0.21.rst (+5)

@@ -267,6 +267,11 @@ Support for Python 3.4 and below has been officially dropped.
   :func:`~model_selection.validation_curve` only the latter is required.
   :issue:`12613` and :issue:`12669` by :user:`Marc Torrellas <marctorrellas>`.
 
+- |Fix| Fixed a bug where :class:`model_selection.StratifiedKFold`
+  shuffles each class's samples with the same ``random_state``,
+  making ``shuffle=True`` ineffective.
+  :issue:`13124` by :user:`Hanmin Qin <qinhanmin2014>`.
+
 :mod:`sklearn.neighbors`
 ........................


sklearn/model_selection/_split.py (+2 -3)

@@ -576,8 +576,7 @@ class StratifiedKFold(_BaseKFold):
             ``n_splits`` default value will change from 3 to 5 in v0.22.
 
     shuffle : boolean, optional
-        Whether to shuffle each stratification of the data before splitting
-        into batches.
+        Whether to shuffle each class's samples before splitting into batches.
 
     random_state : int, RandomState instance or None, optional, default=None
         If int, random_state is the seed used by the random number generator;
@@ -620,7 +619,7 @@ def __init__(self, n_splits='warn', shuffle=False, random_state=None):
         super().__init__(n_splits, shuffle, random_state)
 
     def _make_test_folds(self, X, y=None):
-        rng = self.random_state
+        rng = check_random_state(self.random_state)
         y = np.asarray(y)
         type_of_target_y = type_of_target(y)
         allowed_target_types = ('binary', 'multiclass')
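
To see why the one-line change in `_make_test_folds` matters, here is a minimal sketch of the seeding behaviour only. It is not scikit-learn's code: it uses Python's standard-library `random.Random` in place of NumPy's `RandomState`, and `fold_pairs` and `share_rng` are hypothetical names introduced for illustration. It assumes the old code effectively built a fresh generator from the same integer seed for every class, which is what the changelog entry describes.

```python
import random


def fold_pairs(seed, n_per_class=5, share_rng=False):
    """Pair the i-th shuffled sample of class 0 with that of class 1.

    Each pair stands in for one test fold of a two-class, 5-fold split.
    Hypothetical helper: a simplified model, not scikit-learn's logic.
    """
    shared = random.Random(seed)
    shuffled = []
    for offset in (0, n_per_class):
        # share_rng=False: every class re-seeds from the same integer.
        # share_rng=True: every class draws from one advancing generator.
        rng = shared if share_rng else random.Random(seed)
        indices = list(range(offset, offset + n_per_class))
        rng.shuffle(indices)
        shuffled.append(indices)
    # Sort so that only the pairing matters, not the order of the folds.
    return sorted(zip(*shuffled))


# Collect the distinct pairings produced by 20 different seeds.
before = {tuple(fold_pairs(s)) for s in range(20)}
after = {tuple(fold_pairs(s, share_rng=True)) for s in range(20)}
print(len(before), len(after))
```

With a per-class re-seed, both classes receive the same permutation, so sample `i` of class 0 always shares a fold with sample `i + 5` of class 1, and `before` holds a single pairing whatever the seed. With one shared generator, the second class is shuffled from an advanced state, so `after` should hold several distinct pairings.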

sklearn/model_selection/tests/test_split.py (+11)

@@ -493,6 +493,17 @@ def test_shuffle_stratifiedkfold():
         assert_not_equal(set(test0), set(test1))
     check_cv_coverage(kf0, X_40, y, groups=None, expected_n_splits=5)
 
+    # Ensure that we shuffle each class's samples with different
+    # random_state in StratifiedKFold
+    # See https://github.com/scikit-learn/scikit-learn/pull/13124
+    X = np.arange(10)
+    y = [0] * 5 + [1] * 5
+    kf1 = StratifiedKFold(5, shuffle=True, random_state=0)
+    kf2 = StratifiedKFold(5, shuffle=True, random_state=1)
+    test_set1 = sorted([tuple(s[1]) for s in kf1.split(X, y)])
+    test_set2 = sorted([tuple(s[1]) for s in kf2.split(X, y)])
+    assert test_set1 != test_set2
+
 
 def test_kfold_can_detect_dependent_samples_on_digits():  # see #2372
     # The digits samples are dependent: they are apparently grouped by authors
