Skip to content

Commit d2b405c

Browse files
qinhanmin2014jnothman
authored andcommitted
FIX Shuffle each class's samples with different random_state in StratifiedKFold (scikit-learn#13124)
* Enable StratifiedKFold to produce different splits * what's new * redundant statement * update what's new * redundant comment * add a test * move what's new entry * review comment * review comment
1 parent 7b136e9 commit d2b405c

File tree

3 files changed

+33
-3
lines changed

3 files changed

+33
-3
lines changed

doc/whats_new/v0.20.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,26 @@
77
Version 0.20.3
88
==============
99

10+
**July 29, 2019**
11+
12+
This is a bug-fix release with some bug fixes applied to version 0.20.3.
13+
14+
Changelog
15+
---------
16+
17+
:mod:`sklearn.model_selection`
18+
..............................
19+
20+
- |Fix| Fixed a bug where :class:`model_selection.StratifiedKFold`
21+
shuffles each class's samples with the same ``random_state``,
22+
making ``shuffle=True`` ineffective.
23+
:issue:`13124` by :user:`Hanmin Qin <qinhanmin2014>`.
24+
25+
.. _changes_0_20_3:
26+
27+
Version 0.20.3
28+
==============
29+
1030
**March 1, 2019**
1131

1232
This is a bug-fix release with some minor documentation improvements and

sklearn/model_selection/_split.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,7 @@ class StratifiedKFold(_BaseKFold):
583583
``n_splits`` default value will change from 3 to 5 in v0.22.
584584
585585
shuffle : boolean, optional
586-
Whether to shuffle each stratification of the data before splitting
587-
into batches.
586+
Whether to shuffle each class's samples before splitting into batches.
588587
589588
random_state : int, RandomState instance or None, optional, default=None
590589
If int, random_state is the seed used by the random number generator;
@@ -626,7 +625,7 @@ def __init__(self, n_splits='warn', shuffle=False, random_state=None):
626625
super(StratifiedKFold, self).__init__(n_splits, shuffle, random_state)
627626

628627
def _make_test_folds(self, X, y=None):
629-
rng = self.random_state
628+
rng = check_random_state(self.random_state)
630629
y = np.asarray(y)
631630
type_of_target_y = type_of_target(y)
632631
allowed_target_types = ('binary', 'multiclass')

sklearn/model_selection/tests/test_split.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,17 @@ def test_shuffle_stratifiedkfold():
504504
assert_not_equal(set(test0), set(test1))
505505
check_cv_coverage(kf0, X_40, y, groups=None, expected_n_splits=5)
506506

507+
# Ensure that we shuffle each class's samples with different
508+
# random_state in StratifiedKFold
509+
# See https://github.com/scikit-learn/scikit-learn/pull/13124
510+
X = np.arange(10)
511+
y = [0] * 5 + [1] * 5
512+
kf1 = StratifiedKFold(5, shuffle=True, random_state=0)
513+
kf2 = StratifiedKFold(5, shuffle=True, random_state=1)
514+
test_set1 = sorted([tuple(s[1]) for s in kf1.split(X, y)])
515+
test_set2 = sorted([tuple(s[1]) for s in kf2.split(X, y)])
516+
assert test_set1 != test_set2
517+
507518

508519
def test_kfold_can_detect_dependent_samples_on_digits(): # see #2372
509520
# The digits samples are dependent: they are apparently grouped by authors

0 commit comments

Comments
 (0)