-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG + 1] Set a random random_state at init to ensure deterministic randomness when random_state is None
#7935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5dd260c
75448b9
6f5947d
d05efd4
374d2ed
d3db617
c762363
318312e
a637c99
bc9f22c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,7 +151,7 @@ def test_cross_validator_with_default_params(): | |
skf = StratifiedKFold(n_splits) | ||
lolo = LeaveOneGroupOut() | ||
lopo = LeavePGroupsOut(p) | ||
ss = ShuffleSplit(random_state=0) | ||
ss = ShuffleSplit(random_state=42) | ||
ps = PredefinedSplit([1, 1, 2, 2]) # n_splits = np of unique folds = 2 | ||
|
||
loo_repr = "LeaveOneOut()" | ||
|
@@ -160,7 +160,7 @@ def test_cross_validator_with_default_params(): | |
skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)" | ||
lolo_repr = "LeaveOneGroupOut()" | ||
lopo_repr = "LeavePGroupsOut(n_groups=2)" | ||
ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, test_size=0.1, " | ||
ss_repr = ("ShuffleSplit(n_splits=10, random_state=42, test_size=0.1, " | ||
"train_size=None)") | ||
ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))" | ||
|
||
|
@@ -423,26 +423,43 @@ def test_shuffle_kfold_stratifiedkfold_reproducibility(): | |
X2 = np.ones(16) # Not divisible by 3 | ||
y2 = [0] * 8 + [1] * 8 | ||
|
||
# random_state set to int with shuffle=True | ||
kf = KFold(3, shuffle=True, random_state=0) | ||
skf = StratifiedKFold(3, shuffle=True, random_state=0) | ||
|
||
for cv in (kf, skf): | ||
np.testing.assert_equal(list(cv.split(X, y)), list(cv.split(X, y))) | ||
np.testing.assert_equal(list(cv.split(X2, y2)), list(cv.split(X2, y2))) | ||
|
||
kf = KFold(3, shuffle=True) | ||
skf = StratifiedKFold(3, shuffle=True) | ||
|
||
for cv in (kf, skf): | ||
for data in zip((X, X2), (y, y2)): | ||
# random_state set to RandomState object with shuffle=True | ||
kf2 = KFold(3, shuffle=True, random_state=np.random.RandomState(0)) | ||
skf2 = StratifiedKFold(3, shuffle=True, | ||
random_state=np.random.RandomState(0)) | ||
# random_state not set with shuffle=True | ||
kf3 = KFold(3, shuffle=True) | ||
skf3 = StratifiedKFold(3, shuffle=True) | ||
|
||
# 1) Test to ensure consistent behavior for multiple split calls | ||
# irrespective of random_state | ||
for cv in (kf, skf, kf2, skf2, kf3, skf3): | ||
for data in ((X, y), (X2, y2)): | ||
# Check that calling split twice yields the same results | ||
np.testing.assert_equal(list(cv.split(*data)), | ||
list(cv.split(*data))) | ||
|
||
# 2) Tests to ensure different initializations produce different splits, | |
# when random_state is not set | ||
kf1 = KFold(3, shuffle=True) | ||
skf1 = StratifiedKFold(3, shuffle=True) | ||
kf2 = KFold(3, shuffle=True) | ||
skf2 = StratifiedKFold(3, shuffle=True) | ||
for cv1, cv2 in ((kf1, kf2), (skf1, skf2)): | ||
for data in ((X, y), (X2, y2)): | ||
# For different initialisations, splits should not be same when | ||
# random_state is not set. | ||
try: | ||
np.testing.assert_equal(list(cv.split(*data)), | ||
list(cv.split(*data))) | ||
np.testing.assert_equal(list(cv1.split(*data)), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I still think this is an ugly pattern but whatever... Is this the only place where we use this? Otherwise we might want to put this into There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe for another PR? There are quite a few places (even outside There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure. Though I was referring to the |
||
list(cv2.split(*data))) | ||
except AssertionError: | ||
pass | ||
else: | ||
raise AssertionError("The splits for data, %s, are same even " | ||
"when random state is not set" % data) | ||
raise AssertionError("When random_state is not set, the splits" | ||
" are same for different initializations") | ||
|
||
|
||
def test_shuffle_stratifiedkfold(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gaël had instead suggested the use of `get_state`. Apart from the ability to represent many more than 2**32 possible states, it is faster to set the state than to seed a new `RandomState`. I'm okay with this though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doing `check_random_state(self.random_state).get_state()` in all cases will ensure similar behaviour regardless of the form of the parameter.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, could you elaborate a bit more? Get state gets me a tuple of the `RandomState(...)` object which represents its state... Do you mean we store this at `random_state`? (That garbles the `repr`...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean by "garble"? Show a RandomState object? That's fine, isn't it?