
[architectural suggestion] Move random number generator to initializer in model_selection._split #6726


Closed
DSLituiev opened this issue Apr 27, 2016 · 15 comments


@DSLituiev
Contributor

DSLituiev commented Apr 27, 2016

This suggestion concerns random shuffling in the new model_selection module.

I ran into a challenge in the following setup: I do a grid search with CV, and I want the CV reshuffling to be consistent for each parameter I am looping through. This now seems impossible to do with model_selection.KFold, as copy.copy() and copy.deepcopy() raise an error when called in the following sequence:

import copy
import numpy as np
import sklearn.model_selection

y = np.random.randn(100)
n_folds = 10
kf_ = sklearn.model_selection.KFold(n_folds=n_folds, shuffle=True)
kf = kf_.split(y)
# copy.copy(kf) raises a TypeError: generator objects cannot be copied
for tr, ts in copy.copy(kf):
    print(ts)

Copying it earlier does not make sense, as the RNG is initialized only during the kf_.split(y) call.

One solution is to specify the seed for each shuffling fold. Another, more fundamental, solution is to refactor model_selection and move the check_random_state(self.random_state) call, as here, to the __init__ of _BaseKFold; then each kf_.split(y) call would give consistently shuffled indices.
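For concreteness, a minimal toy sketch of that refactor (ToyKFold is hypothetical and much simplified; the real _BaseKFold has more machinery):

import numpy as np
from sklearn.utils import check_random_state

class ToyKFold:
    """Hypothetical sketch: the RNG is created once, in __init__."""

    def __init__(self, n_folds=3, shuffle=True, random_state=None):
        self.n_folds = n_folds
        self.shuffle = shuffle
        # Proposed change: build the RNG here rather than inside split()
        self.rng = check_random_state(random_state)

    def split(self, X):
        indices = np.arange(len(X))
        if self.shuffle:
            # Note: self.rng's internal state advances on every call
            self.rng.shuffle(indices)
        for test in np.array_split(indices, self.n_folds):
            train = np.setdiff1d(indices, test)
            yield train, test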

Versions

Python 3.5.1 (v3.5.1:37a07cee5969, Dec 5 2015, 21:12:44)
[GCC 4.2.1 (Apple Inc. build 5666) (dot 3)]
NumPy 1.11.0
SciPy 0.17.0
Scikit-Learn 0.18.dev0

@DSLituiev DSLituiev changed the title from "[architectural suggestion] Move random number generator to initializer in model_selection" to "[architectural suggestion] Move random number generator to initializer in model_selection._split" on Apr 27, 2016
@DSLituiev
Contributor Author

DSLituiev commented Apr 27, 2016

Note that the functionality I describe is already available in the cross_validation module:

import numpy as np
import sklearn.cross_validation

y = np.random.randn(100)
n_folds = 9
kf = sklearn.cross_validation.KFold(len(y), n_folds=n_folds, shuffle=True)

for tr, ts in kf:
    print(ts)
# is the same as:
for tr, ts in kf:
    print(ts)

@TomDLT
Member

TomDLT commented Apr 28, 2016

What about the random_state parameter?

import numpy as np
from sklearn.model_selection import KFold
y = np.random.randn(100)
kf = KFold(n_folds=9, shuffle=True, random_state=0)

for tr, ts in kf.split(y):
    print(ts)
# is the same as:
print('')
for tr, ts in kf.split(y):
    print(ts)

The cross_validation module is deprecated; you should use model_selection if it is available in your version.

@DSLituiev
Contributor Author

DSLituiev commented May 1, 2016

I do use model_selection, and this is actually the trick I am using now. But it looks more like a workaround than good practice. When I need different random states, I now have to specify them every time. The seed does not depend on the data y, and I see no reason not to factor it out into the initializer, other than the intention to generate a different shuffling with the same object.
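For reference, a minimal sketch of that workaround (using the n_splits spelling of the released 0.18 API):

import numpy as np
from sklearn.model_selection import KFold

y = np.random.randn(100)

# Fixing random_state makes every split() call reshuffle identically
kf = KFold(n_splits=5, shuffle=True, random_state=0)
first = [test for _, test in kf.split(y)]
second = [test for _, test in kf.split(y)]
assert all(np.array_equal(a, b) for a, b in zip(first, second))

# ...but a different shuffling now requires choosing a new seed by hand
kf_other = KFold(n_splits=5, shuffle=True, random_state=1)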

@amueller amueller added the Bug label Oct 11, 2016
@amueller amueller added this to the 0.18.1 milestone Oct 11, 2016
@amueller
Member

So you want the splits to be random, but consistent across all parameters?
And that was the case in 0.17 but is no longer in 0.18... That is indeed a bit strange.
I think I agree with moving the check_random_state.

If only I had read this before the release :-/ (also, we merged model_selection in April already? I'm getting so old...)

ping @jnothman @raghavrv for their opinions.

@jnothman
Member

jnothman commented Oct 11, 2016

Yes, it is a problem that split() is run in GridSearch et al per parameter setting. Yuck.

We should call splits = list(cv.split(X, y)) once.

Whether or not we also move the random_state initialisation to __init__ is a separate issue. But certainly our current usage of this is bad.

(And yes, you can create the list and should not need to copy or deepcopy in most cases.)
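A minimal sketch of that pattern, calling split() once and reusing the materialized list across parameter settings (the parameter loop here is a hypothetical stand-in for what the search does internally):

import numpy as np
from sklearn.model_selection import KFold

X, y = np.random.randn(100, 3), np.random.randn(100)
cv = KFold(n_splits=5, shuffle=True)  # no random_state fixed

# Materialize the splits once; every parameter setting then sees
# exactly the same folds, even though the splitter is unseeded.
splits = list(cv.split(X, y))

for alpha in [0.1, 1.0, 10.0]:
    for train, test in splits:
        pass  # fit and score a model with this alpha on identical folds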

Maybe we needed the new CV splitters to be an experimental release before deprecating the old ones.

@amueller and @ogrisel, given that we've found some substantial issues (this and pickling at least), can we set some goals for 0.18.1 in terms of timescale? Aim for an RC at the beginning of November, for instance? And then focus almost exclusively on release-critical issues? I also think the 0.18.1 release plan should be noted on the mailing list.

@amueller
Member

I have to think about this issue again; I don't think I really understood it. I'm not sure how the behavior was before.
I'm starting a new job November first (hopefully, or I get kicked out of the country), so doing 0.18.1 in October would work better for me.

@raghavrv
Member

Yes, it is a problem that split() is run in GridSearch et al per parameter setting.

@jnothman Unless the user is using a cv splitter that produces different train/test sets for each split() call, I don't see the problem here. It was like that before in cross_validation, and we decided to preserve it as such.

from sklearn.model_selection import ShuffleSplit
from sklearn.cross_validation import ShuffleSplit as OldShuffleSplit

import numpy as np

y = np.random.RandomState(0).randn(100)
n_iter = 10
ss = ShuffleSplit(n_splits=n_iter, random_state=0)
old_ss_iter = OldShuffleSplit(len(y), n_iter=n_iter, random_state=0)

new_ss_iter_1 = list(ss.split(y))
new_ss_iter_2 = list(ss.split(y))

# Let's change the random state and make a 3rd split(y) call.
ss.random_state = 42
new_ss_iter_3 = list(ss.split(y))

for trte1, trte2, trte3, trte4 in zip(
        new_ss_iter_1, new_ss_iter_2, new_ss_iter_3, old_ss_iter):
    np.testing.assert_equal(trte1, trte2)
    np.testing.assert_equal(trte1, trte4)
    # This will fail, as the random states are not the same
    np.testing.assert_equal(trte2, trte3)

And you will observe the same behavior with KFold too...

from sklearn.model_selection import KFold
from sklearn.cross_validation import KFold as OldKFold

y = np.random.RandomState(0).randn(100)
n_splits = 10

kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
old_kf_iter = OldKFold(len(y), n_folds=n_splits, shuffle=True, random_state=0)

new_kf_iter_1 = list(kf.split(y))
new_kf_iter_2 = list(kf.split(y))

# Let's change the random state and make a 3rd split(y) call.
kf.random_state = 42
new_kf_iter_3 = list(kf.split(y))

for trte1, trte2, trte3, trte4 in zip(
        new_kf_iter_1, new_kf_iter_2, new_kf_iter_3, old_kf_iter):
    np.testing.assert_equal(trte1, trte2)
    np.testing.assert_equal(trte1, trte4)
    # This will fail, as the random states are not the same
    np.testing.assert_equal(trte2, trte3)

DSLituiev: ... moved to the __init__ of the _BaseKFold, and then each kf_.split(y) will give consistently shuffled indices.
Andy: So you want the splits to be random, but consistent across all parameters? And that was the case in 0.17 but not any more in 0.18... That is indeed a bit strange.

It gives consistently shuffled indices now (see the above code snippets for proof) only because the RNG is not created at __init__. If the RNG were set at __init__ with, say, a random state of 42, then at the end of one call to split(X, y) it would most probably no longer be at random state 42, and would hence produce different train/test sets in the subsequent call to split(), inconsistent with the first...
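To spell out the point about the advancing stream, a small numpy sketch:

import numpy as np

rng = np.random.RandomState(42)
first = rng.permutation(10)
second = rng.permutation(10)  # the state has advanced, so this differs
assert not np.array_equal(first, second)

# Re-creating the RNG inside each split() call (the current behavior)
# is what makes repeated calls consistent:
rng = np.random.RandomState(42)
assert np.array_equal(first, rng.permutation(10))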

@DSLituiev Could you share sample code that works for you with cross_validation's KFold but not with model_selection's? Unless I am misunderstanding something, there is no issue here...

@jnothman
Member

jnothman commented Oct 12, 2016

I don't think you should assume most users set random_state, but they may have assumed that results across folds for different grid search parameter settings would be comparable. Formerly, all built-in CV iterators were designed to be invariant across calls to __iter__, so this assumption was upheld, for those iterators at least. So in this sense, setting shuffle=True without fixing random_state will now break the user's expectations in grid search, with seemingly-equivalent code.

@DSLituiev's particular issue can be solved by documenting that anyone who wants consistent splits should use list(cv.split(X, y)). But that does not fix the issue in grid search, which comes from a combination of where you call check_random_state and where you call split(). We need to choose one or both of those fixes, and we need to document and test these assumptions about variation (or lack thereof) with respect to both split and search.

I now think there is also an issue in CVIterableWrapper, because it assumes that the provided iterable is reentrant. While this too is consistent with the former use of cv in GridSearchCV, it needs to be clearly documented that not just any iterable will do, but only something that produces the same sequence each time iter is called: effectively a (perhaps dynamically generated) sequence. Or else we can call split() exactly once and remove this requirement.
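A toy illustration of the reentrancy distinction (with made-up splits):

def make_splits():
    # A generator is exhausted after one pass: NOT reentrant
    yield [0, 1], [2]
    yield [0, 2], [1]

gen = make_splits()
print(list(gen))  # two (train, test) pairs
print(list(gen))  # [] -- a second pass sees nothing

# A concrete sequence yields the same splits on every pass
splits = list(make_splits())
print(list(splits))  # same two pairs
print(list(splits))  # same again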

@jnothman
Member

Being practical about the timescale, @amueller: doing 0.18.1 in October means releasing in November. There are only two weeks of October left.

@raghavrv
Member

So in this sense, setting shuffle=True without fixing random_state will now break the user's expectations in grid search, with seemingly-equivalent code.

Ah. I understand the issue... Thanks for the clarification.

Or else we can call split() exactly once and remove this requirement.

So IIUC, this, plus documenting that splitters will not be consistent without random_state set and that users will have to do list(cv.split...) to get consistent splits, would fix all the concerns raised in this issue, correct? I'll send a PR soon...

@jnothman
Member

Well, it's clear there are a few ways to define the problem. I think that solution is good, under the assumption that we always want the same splits to be used in parameter searches.


@raghavrv
Member

PR at #7660

@raghavrv
Member

I'm reopening, as #7660 does not seem to fix this satisfactorily... There is an alternative fix at #7935, which attempts to address it in a better way...

@raghavrv raghavrv reopened this Nov 24, 2016
@raghavrv raghavrv removed this from the 0.18.2 milestone Nov 30, 2016
@raghavrv
Member

I'm closing this, as we decided that the current behavior after #7660 is sufficient.

@jnothman
Member

jnothman commented Jun 15, 2017 via email
