FIX make creation of dataset deterministic in SGD #19716


Conversation

PierreAttard
Contributor

Reference Issues/PRs

Fixes #19603

What does this implement/fix? Explain your changes.

check_random_state is moved up before make_dataset so that the same random state is used during the whole process.
This change affects the fit method of the SGDRegressor and SparseSGDRegressor classes.
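
To illustrate the idea, here is a simplified, standalone sketch (not the actual scikit-learn source; the calls below only mimic validating the random state once and then reusing the same instance for every later random step):

    import numpy as np
    from sklearn.utils import check_random_state

    random_state = check_random_state(0)  # validate the user-supplied value once
    # the dataset creation draws its seed from this shared RandomState ...
    dataset_seed = random_state.randint(1, np.iinfo(np.int32).max)
    # ... so any later random step (e.g. a validation split) sees an
    # already-advanced stream from the very same instance
    validation_indices = random_state.permutation(20)[:2]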

Any other comments?

I had to adapt the existing unit test test_validation_set_not_used_for_training in test_sgd.py:

With this change, the random state is "used" once before the following line:

validation_mask = self._make_validation_split(y)

So I reproduce this effect in the unit test with these changes:

        rng = np.random.RandomState(seed)
        rng.randint(1, np.iinfo(np.int32).max)
        cv = ShuffleSplit(test_size=validation_fraction,
                          random_state=rng)
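
To make the effect concrete, here is a small standalone example (assuming only numpy) of why a single extra draw changes everything that follows:

    import numpy as np

    seed = 0
    fresh = np.random.RandomState(seed)
    consumed = np.random.RandomState(seed)
    consumed.randint(1, np.iinfo(np.int32).max)  # the one extra draw done before the split

    # The two generators are now out of sync, so any split derived from them
    # will (almost surely) differ.
    print(fresh.randint(10, size=5))     # one sequence
    print(consumed.randint(10, size=5))  # a different sequence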

I am not totally satisfied with this; what do you think?

If you're OK with it, I will add an entry to the what's new doc and a non-regression test checking that the model is fully deterministic when the random state is seeded, as suggested by @ogrisel.

Thanks in advance for your responses!

Member

@ogrisel ogrisel left a comment


Could you please add a non-regression test that checks that calling SGDRegressor twice on the same data, without waiting for convergence (for instance using max_iter=1), gives the same model.coef_ and model.intercept_ attributes?
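
Something along these lines, as a rough sketch of the requested test (names and exact parameters are only illustrative, not necessarily the test that ends up in the PR):

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.linear_model import SGDRegressor

    X, y = make_regression(n_samples=100, n_features=5, random_state=0)

    # Stop long before convergence; with a fixed random_state both runs
    # should still yield exactly the same parameters.
    reg_1 = SGDRegressor(max_iter=1, tol=None, random_state=42).fit(X, y)
    reg_2 = SGDRegressor(max_iter=1, tol=None, random_state=42).fit(X, y)

    np.testing.assert_allclose(reg_1.coef_, reg_2.coef_)
    np.testing.assert_allclose(reg_1.intercept_, reg_2.intercept_)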

      cv = ShuffleSplit(test_size=validation_fraction,
    -                   random_state=seed)
    +                   random_state=rng)
Member


I do not understand the motivation behind those changes. Passing seed directly was fine, no?

Member

@ogrisel ogrisel Mar 19, 2021


Sorry I had not read the description of the PR...

I think we need to rework the test to make it less dependent on internal details. Not sure how though...

Member

@ogrisel ogrisel Mar 19, 2021


You can probably just move the creation of the validation split (the lines that define validation_mask and validation_score_cb) just after the call to check_random_state.

It's still dependent a bit on internal details but the test would appear less convoluted.

Contributor Author


That's what I explained in the PR message.

(The "Any other comments?" section of the PR description is quoted here, about adapting test_validation_set_not_used_for_training.)

If you use the seed directly, the results will be different, because inside the _fit_regressor method the random object is drawn from once with the given seed before _make_validation_split is used. That is why, in the test, I create rng and draw from it once before the validation split, to reproduce that behavior. But I am not really convinced by this.

Sorry if I am not being clear!

Contributor Author


The term "run" is probably not the right word; in this case it means:

rng.randint(1, np.iinfo(np.int32).max)

Member

@ogrisel ogrisel Mar 19, 2021


Alright I get it now...

Still, I don't like the way this test is written. I think we should rewrite it completely, for instance by fitting a model on a random target, with and without early stopping on a validation split. The accuracy measured on (X_train, y_train) and the value of the n_iter_ attribute should be much higher for the model without validation-set early stopping than for the model with early stopping.

It would not test exactly the same thing but it would be more maintainable.
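
As a rough sketch of that idea (illustrative only; because the target is random, the assertions are only very likely, not guaranteed, to hold, as noted in the reply below):

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    rng = np.random.RandomState(42)
    X = rng.randn(200, 300)
    y = rng.randint(0, 2, size=200)   # random target: nothing real to learn

    common = dict(max_iter=1000, tol=1e-3, random_state=0)
    clf_plain = SGDClassifier(**common).fit(X, y)
    clf_early = SGDClassifier(early_stopping=True, validation_fraction=0.2,
                              **common).fit(X, y)

    # Without early stopping, the model keeps fitting noise on the training set,
    # so its training accuracy and epoch count should (very probably) be higher.
    assert clf_plain.score(X, y) > clf_early.score(X, y)
    assert clf_plain.n_iter_ > clf_early.n_iter_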

Member


I made it such that there's no change in the model. I explained below that SGD regressors make no use of the seed attribute of the generated dataset. So I think it's better to just document that in the PR and leave this test as is. In addition, the test you propose is not guaranteed to pass, only quite probable.

@ogrisel
Member

ogrisel commented Mar 19, 2021

You will also need to document the fix in doc/whats_new/v0.24.rst for the 0.24.2 release.

@ogrisel ogrisel modified the milestones: 0.24, 0.24.2 Mar 19, 2021
@ogrisel
Member

ogrisel commented Mar 30, 2021

@PierreAttard gentle reminder that there are still some comments to address in this PR ;)

@PierreAttard
Contributor Author

Yes, indeed! I'll try to continue the work next week!

@glemaitre glemaitre self-requested a review April 21, 2021 21:17
@glemaitre
Member

I replaced the test as @ogrisel mentioned and added an entry in what's new.

@ogrisel can you have a quick look to see if the test reflects what you had in mind?

@PierreAttard
Contributor Author

Thanks a lot @glemaitre. I haven't had much time; I'll check as soon as possible.

@glemaitre
Member

I see that the CIs are failing. It is probably linked to the fact that we changed the way the random state works.
This is not a big deal; we need to mention it in the what's new. But I am thinking it could be better to move this PR to 1.0 instead of 0.24.2, because it would be odd to have a behaviour change in a bug-fix release.

@ogrisel WDYT?

@glemaitre glemaitre modified the milestones: 0.24.2, 1.0 Apr 22, 2021
@PierreAttard
Contributor Author

PierreAttard commented May 1, 2021

Hi @glemaitre, I ran the test test_validation_set_not_used_for_training several times.
Most of the time, the 4 tests pass, but sometimes one of them fails, and only for the SGDClassifier estimator.

Did you notice that too?
Shouldn't we make it reproducible?

Sorry again for the delay in reviewing.

@glemaitre glemaitre self-assigned this Jul 27, 2021
@glemaitre
Member

I resolved the conflicts in this PR and made the test deterministic by creating a deterministic dataset.

glemaitre
glemaitre previously approved these changes Jul 27, 2021
Member

@glemaitre glemaitre left a comment


LGTM

@glemaitre glemaitre removed their assignment Jul 27, 2021
@glemaitre glemaitre changed the title check_random_state moved up in order to use it in make_dataset and improve reproduction FIX make creation of dataset deterministic in SGD Jul 29, 2021
Member

@adrinjalali adrinjalali left a comment


otherwise LGTM

@adrinjalali
Member

Please merge with main and fix the merge conflict as well :)

Member

@jjerphan jjerphan left a comment


Thank you for working on this fix, @PierreAttard.

Here are some suggestions, a few of which might be resolved by merging main into this PR branch.

@jjerphan
Member

jjerphan commented Nov 8, 2021

Hello @PierreAttard,

Do you still have some time to work on this PR?

@jeremiedbb
Member

jeremiedbb commented Mar 23, 2022

Turns out that

  • make_dataset creates a SequentialDataset which has a seed attribute. This attribute is only used to perform random sampling through the random method of the dataset.
  • Shuffling of the dataset is done by directly passing a pointer to a (de facto) mutable seed to the shuffle method of the dataset.

Hence, passing the random_state to make_dataset in _fit_regressor only has an effect on the random sampling part.
However, in _plain_sgd (called by _fit_regressor) there is no random sampling on the dataset (only shuffling), so this change has absolutely no effect.

I updated the doc of make_dataset and of the seed attribute of the dataset.

@jeremiedbb jeremiedbb dismissed stale reviews from jjerphan and glemaitre March 23, 2022 16:23

everything's changed

@jjerphan jjerphan self-requested a review March 23, 2022 16:24
@ogrisel
Member

ogrisel commented Mar 31, 2022

I added a new non-regression test, and I confirm that there was no non-determinism bug on main and that this PR does not change that behavior.

Member

@ogrisel ogrisel left a comment


I closed #19603 because there is no bug in the end but +1 for merging this PR to make the code less surprising.

@ogrisel
Member

ogrisel commented Mar 31, 2022

Note, I have checked that the new test passes with SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all".

Member

@jjerphan jjerphan left a comment


LGTM. Thank you @PierreAttard and @jeremiedbb.

@jeremiedbb
Member

The failure should be unrelated to this PR (see #23014), but it's not triggered in other PRs. I'm restarting the job to see if it's deterministic.

@jeremiedbb
Member

That random failure was really weird...
Let's still merge this PR and try to find what's going on in #23014

@jeremiedbb jeremiedbb merged commit b4da3b4 into scikit-learn:main Apr 2, 2022
@jeremiedbb
Member

Thanks @PierreAttard!

glemaitre added a commit to glemaitre/scikit-learn that referenced this pull request Apr 6, 2022

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
jjerphan pushed a commit to jjerphan/scikit-learn that referenced this pull request Apr 29, 2022

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Successfully merging this pull request may close these issues.

_fit_regressor in stochastic_gradient.py does not use random state for call to make_dataset