TST use global_random_seed in sklearn/decomposition/tests/test_sparse_pca.py #31213


Merged: 6 commits into scikit-learn:main from the tests branch, Apr 17, 2025

Conversation

DeaMariaLeon (Contributor)

Reference Issues/PRs

Towards #22827

What does this implement/fix? Explain your changes.

Any other comments?

I wonder why test_mini_batch_fit_transform is kept. It is skipped all the time. It comes from PR #12253

cc @glemaitre

Tests updated to use global_random_seed:

  • test_fit_transform
  • test_fit_transform_parallel
  • test_fit_transform_tall
  • test_initialization
  • test_scaling_fit_transform
  • test_pca_vs_spca
  • test_sparse_pca_inverse_transform
  • test_transform_inverse_transform_round_trip
github-actions bot commented Apr 16, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit 0bb7336.

@DeaMariaLeon (Contributor, Author)

Testing with all the seed values passed locally.

@lucyleeow (Member) left a comment

LGTM!

(no idea about test_mini_batch_fit_transform though ...)

@jeremiedbb (Member) left a comment

I think you can add global_random_seed to

  • test_transform_nan
  • test_sparse_pca_numerical_consistency
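
Below is a hedged sketch of what wiring the fixture into one of the tests listed above could look like. The test body is illustrative, not the PR's actual diff; global_random_seed is the pytest fixture from scikit-learn's test suite, whose seed sweep is controlled via the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable.

```python
import numpy as np
from sklearn.decomposition import SparsePCA

def test_transform_nan(global_random_seed):
    # The fixture injects an int seed; pass it straight to the estimator.
    rng = np.random.RandomState(global_random_seed)
    X = rng.randn(20, 10)
    X[:, 0] = 0.0  # an all-zero feature, the kind of input that could yield NaN
    spca = SparsePCA(n_components=2, random_state=global_random_seed)
    assert not np.isnan(spca.fit_transform(X)).any()
```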

Comment on lines 195 to 198
```python
spca = SparsePCA(
    alpha=0, ridge_alpha=0, n_components=2, random_state=global_random_seed
)
pca = PCA(n_components=2, random_state=global_random_seed)
```
@jeremiedbb (Member)

It's not required to start from the same seed in the test, is it?

Suggested change:

```diff
-spca = SparsePCA(
-    alpha=0, ridge_alpha=0, n_components=2, random_state=global_random_seed
-)
-pca = PCA(n_components=2, random_state=global_random_seed)
+spca = SparsePCA(
+    alpha=0, ridge_alpha=0, n_components=2, random_state=rng
+)
+pca = PCA(n_components=2, random_state=rng)
```

@DeaMariaLeon (Contributor, Author) commented Apr 17, 2025

Thank you.
I need to ask (maybe @glemaitre?) why it's sometimes OK to use rng (from np.random.RandomState(global_random_seed)) and sometimes the seed, global_random_seed, directly.

@jeremiedbb (Member)

global_random_seed is a number like 0, 1, 42, ...
np.random.RandomState is an object with a state: when you ask it to generate a random value, its state changes, so the next time you ask it gives a different value.
If you pass an integer as random_state to an estimator, it's used to create and seed a RandomState object. Therefore, each time the estimator is fitted, it starts from the same seed and produces the same results.

So to answer your question, it depends on whether you need to reproduce the exact same sequence of generated values multiple times, or whether it doesn't matter.
For instance, if you want to compare the results of an estimator between float32 and float64 input, you want both estimators to use the same sequence of generated values.
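
A minimal sketch of that distinction using plain NumPy draws (the seed value 42 is just an example):

```python
import numpy as np

# Passing an integer around: every consumer builds a fresh RandomState
# from it, so each one replays exactly the same sequence.
a = np.random.RandomState(42).rand(3)
b = np.random.RandomState(42).rand(3)
assert np.allclose(a, b)

# Sharing one RandomState object: its internal state advances on every
# draw, so successive consumers see different values.
rng = np.random.RandomState(42)
c = rng.rand(3)
d = rng.rand(3)
assert not np.allclose(c, d)
```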

@DeaMariaLeon (Contributor, Author)

Thank you very much @jeremiedbb... I see that my question wasn't clear.

In the file modified by this PR, some tests set SparsePCA's random_state parameter as random_state=rng, where rng = np.random.RandomState(global_random_seed). For example, test_fit_transform_tall.

But some tests set the random_state parameter as random_state=global_random_seed. Like test_sparse_pca_numerical_consistency.

I guess it doesn't matter, as there is check_random_state(seed), which "turn[s] seed into a np.random.RandomState instance".

In the end they might both be the same, except that changing the random_state of test_sparse_pca_numerical_consistency to random_state=rng was failing.

Anyway, your previous explanation confirmed my understanding. :-)
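
For reference, a small sketch of check_random_state's documented behavior, which is why both spellings are accepted by estimators:

```python
import numpy as np
from sklearn.utils import check_random_state

# An int is wrapped into a freshly seeded RandomState on every call, so
# random_state=global_random_seed restarts the stream at each fit.
rs = check_random_state(42)
assert isinstance(rs, np.random.RandomState)

# An existing RandomState passes through unchanged, state and all, so
# random_state=rng shares one advancing stream across estimators.
rng = np.random.RandomState(42)
assert check_random_state(rng) is rng
```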

@jeremiedbb (Member)

> I wonder why test_mini_batch_fit_transform is kept. It is skipped all the time.

Let's just remove it. It's useless to keep dead code in the code base.

@DeaMariaLeon (Contributor, Author) commented Apr 17, 2025

> I think you can add global_random_seed to
>
>   • test_transform_nan
>   • test_sparse_pca_numerical_consistency

I added the seed to those.
test_sparse_pca_numerical_consistency was failing a lot and I tried many things... maybe now it's too sparse? The "many things" were increasing rtol, adding atol, etc.

Updated list of tests using global_random_seed:

  • test_fit_transform
  • test_fit_transform_parallel
  • test_transform_nan
  • test_fit_transform_tall
  • test_initialization
  • test_scaling_fit_transform
  • test_pca_vs_spca
  • test_sparse_pca_numerical_consistency
  • test_sparse_pca_inverse_transform
  • test_transform_inverse_transform_round_trip
@jeremiedbb (Member)

> test_sparse_pca_numerical_consistency was failing a lot and I tried many things... maybe now it's too sparse? The "many things" were increasing rtol, adding atol, etc.

Looking into it a bit, it seems to be an issue with the convergence criterion of DictionaryLearning (which is used by SparsePCA). It's not protected against rounding errors leading to negative values, so it can happen that in one case the algorithm keeps running for a few more iterations than in the other.

It's not a major issue and not really related to this PR, so I just switched the dataset for one that is easier for SparsePCA.
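
The PR itself doesn't show the replacement dataset here, but as a hypothetical illustration, "easier for SparsePCA" usually means data whose sparse components are well separated and only lightly noised:

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, n_components, n_features = 50, 3, 12

# Hypothetical dictionary: each component is non-zero on its own block
# of features, so the sparse structure is unambiguous to recover.
V = np.zeros((n_components, n_features))
for k in range(n_components):
    V[k, 4 * k : 4 * k + 4] = rng.randn(4)

U = rng.randn(n_samples, n_components)
X = U @ V + 0.01 * rng.randn(n_samples, n_features)  # low noise level
```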

@jeremiedbb (Member) left a comment

LGTM. Thanks

@jeremiedbb jeremiedbb enabled auto-merge (squash) April 17, 2025 18:11
@jeremiedbb jeremiedbb merged commit ceac4a8 into scikit-learn:main Apr 17, 2025
34 checks passed
@DeaMariaLeon DeaMariaLeon deleted the tests branch April 18, 2025 06:34
lucyleeow pushed a commit to EmilyXinyi/scikit-learn that referenced this pull request Apr 23, 2025
TST use `global_random_seed` in `sklearn/decomposition/tests/test_sparse_pca.py` (scikit-learn#31213)

Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>