Skip to content

Commit 75d6005

Browse files
raghavrv authored and jnothman committed
[MRG] Add few more tests + Documentation for re-entrant cross-validation estimators (#7823)
* DOC Add NOTE that unless random_state is set, split will not be identical
* TST use np.testing.assert_equal for nested lists/arrays
* TST Make sure cv param can be a generator
* DOC rank_ becomes a link when rendered
* Use test_...
* Remove blank line; Add if shuffle is True
* Fix tests
* Explicitly test for GeneratorType
* TST Add the else clause
* TST Add comment on usage of np.testing.assert_array_equal
* TYPO
* MNT Remove if ;
* Address Joel's comments
* merge the identical points in doc
* DOC address Andy's comments
* Move comment to before the check for generator type
1 parent ac1b048 commit 75d6005

File tree

5 files changed

+76
-19
lines changed

5 files changed

+76
-19
lines changed

doc/modules/cross_validation.rst

+1-2
Original file line number | Diff line number | Diff line change
@@ -725,8 +725,7 @@ to shuffle the data indices before splitting them. Note that:
725725
shuffling will be different every time ``KFold(..., shuffle=True)`` is
726726
iterated. However, ``GridSearchCV`` will use the same shuffling for each set
727727
of parameters validated by a single call to its ``fit`` method.
728-
* To ensure results are repeatable (*on the same platform*), use a fixed value
729-
for ``random_state``.
728+
* To get identical results for each call to ``split``, set ``random_state`` to an integer.
730729

731730
Cross validation and model selection
732731
====================================

sklearn/model_selection/_search.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -924,7 +924,7 @@ class GridSearchCV(BaseSearchCV):
924924
For instance the below given table
925925
926926
+------------+-----------+------------+-----------------+---+---------+
927-
|param_kernel|param_gamma|param_degree|split0_test_score|...|..rank...|
927+
|param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
928928
+============+===========+============+=================+===+=========+
929929
| 'poly' | -- | 2 | 0.8 |...| 2 |
930930
+------------+-----------+------------+-----------------+---+---------+

sklearn/model_selection/_split.py

+37-4
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,12 @@ def split(self, X, y=None, groups=None):
8383
8484
test : ndarray
8585
The testing set indices for that split.
86+
87+
Note
88+
----
89+
Randomized CV splitters may return different results for each call of
90+
split. You can make the results identical by setting ``random_state``
91+
to an integer.
8692
"""
8793
X, y, groups = indexable(X, y, groups)
8894
indices = np.arange(_num_samples(X))
@@ -308,6 +314,12 @@ def split(self, X, y=None, groups=None):
308314
309315
test : ndarray
310316
The testing set indices for that split.
317+
318+
Note
319+
----
320+
Randomized CV splitters may return different results for each call of
321+
split. You can make the results identical by setting ``random_state``
322+
to an integer.
311323
"""
312324
X, y, groups = indexable(X, y, groups)
313325
n_samples = _num_samples(X)
@@ -567,10 +579,7 @@ def __init__(self, n_splits=3, shuffle=False, random_state=None):
567579
super(StratifiedKFold, self).__init__(n_splits, shuffle, random_state)
568580

569581
def _make_test_folds(self, X, y=None):
570-
if self.shuffle:
571-
rng = check_random_state(self.random_state)
572-
else:
573-
rng = self.random_state
582+
rng = self.random_state
574583
y = np.asarray(y)
575584
n_samples = y.shape[0]
576585
unique_y, y_inversed = np.unique(y, return_inverse=True)
@@ -645,6 +654,12 @@ def split(self, X, y, groups=None):
645654
646655
test : ndarray
647656
The testing set indices for that split.
657+
658+
Note
659+
----
660+
Randomized CV splitters may return different results for each call of
661+
split. You can make the results identical by setting ``random_state``
662+
to an integer.
648663
"""
649664
y = check_array(y, ensure_2d=False, dtype=None)
650665
return super(StratifiedKFold, self).split(X, y, groups)
@@ -726,6 +741,12 @@ def split(self, X, y=None, groups=None):
726741
727742
test : ndarray
728743
The testing set indices for that split.
744+
745+
Note
746+
----
747+
Randomized CV splitters may return different results for each call of
748+
split. You can make the results identical by setting ``random_state``
749+
to an integer.
729750
"""
730751
X, y, groups = indexable(X, y, groups)
731752
n_samples = _num_samples(X)
@@ -1164,6 +1185,12 @@ def split(self, X, y=None, groups=None):
11641185
11651186
test : ndarray
11661187
The testing set indices for that split.
1188+
1189+
Note
1190+
----
1191+
Randomized CV splitters may return different results for each call of
1192+
split. You can make the results identical by setting ``random_state``
1193+
to an integer.
11671194
"""
11681195
X, y, groups = indexable(X, y, groups)
11691196
for train, test in self._iter_indices(X, y, groups):
@@ -1578,6 +1605,12 @@ def split(self, X, y, groups=None):
15781605
15791606
test : ndarray
15801607
The testing set indices for that split.
1608+
1609+
Note
1610+
----
1611+
Randomized CV splitters may return different results for each call of
1612+
split. You can make the results identical by setting ``random_state``
1613+
to an integer.
15811614
"""
15821615
y = check_array(y, ensure_2d=False, dtype=None)
15831616
return super(StratifiedShuffleSplit, self).split(X, y, groups)

sklearn/model_selection/tests/test_search.py

+32-10
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from itertools import chain, product
88
import pickle
99
import sys
10+
from types import GeneratorType
1011
import re
1112

1213
import numpy as np
@@ -1070,16 +1071,10 @@ def test_search_cv_results_rank_tie_breaking():
10701071
cv_results['mean_test_score'][1])
10711072
assert_almost_equal(cv_results['mean_train_score'][0],
10721073
cv_results['mean_train_score'][1])
1073-
try:
1074-
assert_almost_equal(cv_results['mean_test_score'][1],
1075-
cv_results['mean_test_score'][2])
1076-
except AssertionError:
1077-
pass
1078-
try:
1079-
assert_almost_equal(cv_results['mean_train_score'][1],
1080-
cv_results['mean_train_score'][2])
1081-
except AssertionError:
1082-
pass
1074+
assert_false(np.allclose(cv_results['mean_test_score'][1],
1075+
cv_results['mean_test_score'][2]))
1076+
assert_false(np.allclose(cv_results['mean_train_score'][1],
1077+
cv_results['mean_train_score'][2]))
10831078
# 'min' rank should be assigned to the tied candidates
10841079
assert_almost_equal(search.cv_results_['rank_test_score'], [1, 1, 3])
10851080

@@ -1421,6 +1416,33 @@ def test_grid_search_cv_splits_consistency():
14211416
cv=KFold(n_splits=n_splits))
14221417
gs2.fit(X, y)
14231418

1419+
# Give generator as a cv parameter
1420+
assert_true(isinstance(KFold(n_splits=n_splits,
1421+
shuffle=True, random_state=0).split(X, y),
1422+
GeneratorType))
1423+
gs3 = GridSearchCV(LinearSVC(random_state=0),
1424+
param_grid={'C': [0.1, 0.2, 0.3]},
1425+
cv=KFold(n_splits=n_splits, shuffle=True,
1426+
random_state=0).split(X, y))
1427+
gs3.fit(X, y)
1428+
1429+
gs4 = GridSearchCV(LinearSVC(random_state=0),
1430+
param_grid={'C': [0.1, 0.2, 0.3]},
1431+
cv=KFold(n_splits=n_splits, shuffle=True,
1432+
random_state=0))
1433+
gs4.fit(X, y)
1434+
1435+
def _pop_time_keys(cv_results):
1436+
for key in ('mean_fit_time', 'std_fit_time',
1437+
'mean_score_time', 'std_score_time'):
1438+
cv_results.pop(key)
1439+
return cv_results
1440+
1441+
# Check if generators are supported as cv and
1442+
# that the splits are consistent
1443+
np.testing.assert_equal(_pop_time_keys(gs3.cv_results_),
1444+
_pop_time_keys(gs4.cv_results_))
1445+
14241446
# OneTimeSplitter is a non-re-entrant cv where split can be called only
14251447
# once if ``cv.split`` is called once per param setting in GridSearchCV.fit
14261448
# the 2nd and 3rd parameter will not be evaluated as no train/test indices

sklearn/model_selection/tests/test_split.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -446,9 +446,11 @@ def test_shuffle_kfold_stratifiedkfold_reproducibility():
446446

447447
for cv in (kf, skf):
448448
for data in zip((X, X2), (y, y2)):
449+
# Test if the two splits are different
450+
# numpy's assert_array_equal properly compares nested lists
449451
try:
450-
np.testing.assert_equal(list(cv.split(*data)),
451-
list(cv.split(*data)))
452+
np.testing.assert_array_equal(list(cv.split(*data)),
453+
list(cv.split(*data)))
452454
except AssertionError:
453455
pass
454456
else:
@@ -1188,6 +1190,7 @@ def test_cv_iterable_wrapper():
11881190
# results
11891191
kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
11901192
kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
1193+
# numpy's assert_equal properly compares nested lists
11911194
np.testing.assert_equal(list(kf_randomized_iter_wrapped.split(X, y)),
11921195
list(kf_randomized_iter_wrapped.split(X, y)))
11931196

0 commit comments

Comments
 (0)