
[MRG + 1] Set a random random_state at init to ensure deterministic randomness when random_state is None #7935

Closed
sklearn/model_selection/_split.py: 24 changes (18 additions, 6 deletions)

@@ -284,7 +284,15 @@ def __init__(self, n_splits, shuffle, random_state):

         self.n_splits = n_splits
         self.shuffle = shuffle
-        self.random_state = random_state
+        if shuffle and not isinstance(random_state,
+                                      (np.integer, numbers.Integral)):
+            # This is done to ensure that the multiple calls to split
+            # are random for each initialization of splitter but consistent
+            # across multiple calls for the same initialization.
+            self.random_state = check_random_state(
+                random_state).randint(np.iinfo(np.int32).max)

Member:
Gaël had instead suggested the use of get_state. Apart from the ability to represent many more than 2**32 possible states, it is faster to set the state than to seed a new RandomState.

I'm okay with this though.

Member:
Doing check_random_state(self.random_state).get_state() in all cases will ensure similar behaviour regardless of the form of the parameter.

Member Author:
Sorry, could you elaborate a bit more? get_state gets me a tuple from the RandomState(...) object which represents its state... Do you mean we store this as random_state? (That garbles the repr...)

Member:
What do you mean by "garble"? Show a RandomState object? That's fine, isn't it?

+        else:
+            self.random_state = random_state
 
     def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
@@ -559,10 +567,7 @@ def __init__(self, n_splits=3, shuffle=False, random_state=None):
         super(StratifiedKFold, self).__init__(n_splits, shuffle, random_state)
 
     def _make_test_folds(self, X, y=None, groups=None):
-        if self.shuffle:
-            rng = check_random_state(self.random_state)
-        else:
-            rng = self.random_state
+        rng = check_random_state(self.random_state)
         y = np.asarray(y)
         n_samples = y.shape[0]
         unique_y, y_inversed = np.unique(y, return_inverse=True)
@@ -922,7 +927,14 @@ def __init__(self, n_splits=10, test_size=0.1, train_size=None,
         self.n_splits = n_splits
         self.test_size = test_size
         self.train_size = train_size
-        self.random_state = random_state
+        if isinstance(random_state, (np.integer, numbers.Integral)):
+            self.random_state = random_state
+        else:
+            # This is done to ensure that the multiple calls to split
+            # are random for each initialization of splitter but consistent
+            # across multiple calls for the same initialization.
+            self.random_state = check_random_state(
+                random_state).randint(np.iinfo(np.int32).max)
 
     def split(self, X, y=None, groups=None):
         """Generate indices to split data into training and test set.
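
Taken together, the _split.py changes aim to guarantee the behaviour sketched below: with random_state=None, one splitter instance is consistent across its own split calls, while separately constructed instances differ. Note this illustrates the patched branch (the PR is marked Closed), not necessarily released scikit-learn:

import numpy as np
from sklearn.model_selection import KFold

X = np.ones(12)
y = [0] * 6 + [1] * 6

# Under the patch, random_state=None is resolved to a concrete seed in
# __init__, so the same instance keeps yielding identical splits...
cv = KFold(3, shuffle=True)  # random_state left as None
first = [test.tolist() for _, test in cv.split(X, y)]
second = [test.tolist() for _, test in cv.split(X, y)]
assert first == second

# ...while a separately constructed instance draws its own seed and will
# almost surely shuffle differently.
other = KFold(3, shuffle=True)
print(first == [test.tolist() for _, test in other.split(X, y)])  # likely False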
sklearn/model_selection/tests/test_split.py: 49 changes (33 additions, 16 deletions)

@@ -151,7 +151,7 @@ def test_cross_validator_with_default_params():
     skf = StratifiedKFold(n_splits)
     lolo = LeaveOneGroupOut()
     lopo = LeavePGroupsOut(p)
-    ss = ShuffleSplit(random_state=0)
+    ss = ShuffleSplit(random_state=42)
     ps = PredefinedSplit([1, 1, 2, 2])  # n_splits = np of unique folds = 2
 
     loo_repr = "LeaveOneOut()"
@@ -160,7 +160,7 @@ def test_cross_validator_with_default_params():
     skf_repr = "StratifiedKFold(n_splits=2, random_state=None, shuffle=False)"
     lolo_repr = "LeaveOneGroupOut()"
     lopo_repr = "LeavePGroupsOut(n_groups=2)"
-    ss_repr = ("ShuffleSplit(n_splits=10, random_state=0, test_size=0.1, "
+    ss_repr = ("ShuffleSplit(n_splits=10, random_state=42, test_size=0.1, "
                "train_size=None)")
     ps_repr = "PredefinedSplit(test_fold=array([1, 1, 2, 2]))"

@@ -423,26 +423,43 @@ def test_shuffle_kfold_stratifiedkfold_reproducibility():
     X2 = np.ones(16)  # Not divisible by 3
     y2 = [0] * 8 + [1] * 8
 
+    # random_state set to int with shuffle=True
     kf = KFold(3, shuffle=True, random_state=0)
     skf = StratifiedKFold(3, shuffle=True, random_state=0)
-
-    for cv in (kf, skf):
-        np.testing.assert_equal(list(cv.split(X, y)), list(cv.split(X, y)))
-        np.testing.assert_equal(list(cv.split(X2, y2)), list(cv.split(X2, y2)))
-
-    kf = KFold(3, shuffle=True)
-    skf = StratifiedKFold(3, shuffle=True)
-
-    for cv in (kf, skf):
-        for data in zip((X, X2), (y, y2)):
+    # random_state set to RandomState object with shuffle=True
+    kf2 = KFold(3, shuffle=True, random_state=np.random.RandomState(0))
+    skf2 = StratifiedKFold(3, shuffle=True,
+                           random_state=np.random.RandomState(0))
+    # random_state not set with shuffle=True
+    kf3 = KFold(3, shuffle=True)
+    skf3 = StratifiedKFold(3, shuffle=True)
+
+    # 1) Test to ensure consistent behavior for multiple split calls
+    # irrespective of random_state
+    for cv in (kf, skf, kf2, skf2, kf3, skf3):
+        for data in ((X, y), (X2, y2)):
+            # Check that calling split twice yields the same results
+            np.testing.assert_equal(list(cv.split(*data)),
+                                    list(cv.split(*data)))
+
+    # 2) Tests to ensure different initilization produce different splits,
+    # when random_state is not set
+    kf1 = KFold(3, shuffle=True)
+    skf1 = StratifiedKFold(3, shuffle=True)
+    kf2 = KFold(3, shuffle=True)
+    skf2 = StratifiedKFold(3, shuffle=True)
+    for cv1, cv2 in ((kf1, kf2), (skf1, skf2)):
+        for data in ((X, y), (X2, y2)):
+            # For different initialisations, splits should not be same when
+            # random_state is not set.
             try:
-                np.testing.assert_equal(list(cv.split(*data)),
-                                        list(cv.split(*data)))
+                np.testing.assert_equal(list(cv1.split(*data)),

Member:
I still think this is an ugly pattern but whatever... Is this the only place where we use this? Otherwise we might want to put this into utils?

Member Author:
Maybe for another PR? There are quite a few places (even outside model_selection) using np.testing.assert_*equal...

Member:
Sure. Though I was referring to the try: else: construct.

+                                        list(cv2.split(*data)))
             except AssertionError:
                 pass
             else:
-                raise AssertionError("The splits for data, %s, are same even "
-                                     "when random state is not set" % data)
+                raise AssertionError("When random_state is not set, the splits"
+                                     " are same for different initializations")


 def test_shuffle_stratifiedkfold():
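
If the try/except/else construct were ever factored into a shared utility as the review suggests, a helper along these lines might do it; assert_splits_differ is a hypothetical name, not an existing scikit-learn function:

import numpy as np

def assert_splits_differ(cv1, cv2, X, y):
    # Hypothetical test helper: succeed when two independently constructed
    # splitters produce different splits, fail only when they coincide.
    try:
        np.testing.assert_equal(list(cv1.split(X, y)),
                                list(cv2.split(X, y)))
    except AssertionError:
        return  # splits differ, as expected when random_state is unset
    raise AssertionError("Splits are identical although random_state "
                         "was not set")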