TST use global_random_seed in sklearn/decomposition/tests/test_pca.py #26403

Status: Open. sply88 wants to merge 30 commits into main from global-random-seed-in-test_pca.

Commits (30):
8b0623e  Update test_pca.py (sply88, May 19, 2023)
e83a0c6  Update test_pca.py (sply88, May 19, 2023)
bd492a8  Update test_pca.py (sply88, May 19, 2023)
a6b77ad  Update test_pca.py (sply88, May 19, 2023)
3c5074c  Update test_pca.py (sply88, May 19, 2023)
8ced911  Update test_pca.py (sply88, May 19, 2023)
4fb1dda  Update test_pca.py (sply88, May 19, 2023)
7565e76  Update test_pca.py (sply88, May 19, 2023)
3e3fefa  Update test_pca.py (sply88, May 19, 2023)
d8c8030  Update test_pca.py (sply88, May 19, 2023)
ae793d8  Update test_pca.py (sply88, May 19, 2023)
d332e64  Update test_pca.py (sply88, May 19, 2023)
b9d289d  Update test_pca.py (sply88, May 19, 2023)
f8ead24  Update test_pca.py (sply88, May 19, 2023)
94f1c1d  Update test_pca.py (sply88, May 19, 2023)
11a23a2  Update test_pca.py (sply88, May 19, 2023)
f3dcf74  Update test_pca.py (sply88, May 19, 2023)
7685689  Update test_pca.py (sply88, May 19, 2023)
2c0e456  Update test_pca.py (sply88, May 19, 2023)
2c90998  Update test_pca.py (sply88, May 19, 2023)
5638524  Update test_pca.py (sply88, May 20, 2023)
5b76156  Update test_pca.py (sply88, May 21, 2023)
04a3481  Fix test_pca.py (sply88, May 21, 2023)
799e5ee  Update test_pca.py (sply88, May 22, 2023)
513a97f  Merge branch 'main' into global-random-seed-in-test_pca (sply88, May 22, 2023)
b7be87c  Merge branch 'main' into global-random-seed-in-test_pca (sply88, Jul 1, 2023)
e3f4217  Merge branch 'main' into global-random-seed-in-test_pca (sply88, Jul 13, 2023)
c74872b  Merge branch 'main' into global-random-seed-in-test_pca (sply88, Aug 2, 2023)
6b32dca  Empty commit [all random seeds] (sply88, Jul 4, 2024)
4162786  Merge branch 'main' into global-random-seed-in-test_pca [all random s… (sply88, Oct 24, 2024)
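Note on the fixture: the global_random_seed fixture used throughout this PR comes from scikit-learn's shared test configuration and is controlled by the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable; the [all random seeds] marker in the commit messages above asks CI to exercise the whole 0-99 seed range. The snippet below is a minimal, hypothetical sketch of how such a fixture can be wired up, not the exact conftest code; only the environment variable and fixture names are taken from the project.

```python
# Hypothetical, simplified version of the global_random_seed fixture.
# The real one lives in sklearn/conftest.py.
import os

import pytest


def _seeds_from_env():
    # "42" -> [42], "40-42" -> [40, 41, 42], "all" -> the full 0-99 range;
    # a fixed default keeps local runs deterministic.
    value = os.environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED", "42")
    if value == "all":
        return list(range(100))
    if "-" in value:
        start, stop = value.split("-")
        return list(range(int(start), int(stop) + 1))
    return [int(value)]


@pytest.fixture(params=_seeds_from_env())
def global_random_seed(request):
    """Integer seed for np.random.RandomState in seed-insensitive tests."""
    return request.param
```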
154 changes: 90 additions & 64 deletions sklearn/decomposition/tests/test_pca.py
@@ -216,9 +216,9 @@ def test_no_empty_slice_warning():

@pytest.mark.parametrize("copy", [True, False])
@pytest.mark.parametrize("solver", PCA_SOLVERS)
-def test_whitening(solver, copy):
+def test_whitening(solver, copy, global_random_seed):
    # Check that PCA output has unit-variance
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
    n_samples = 100
    n_features = 80
    n_components = 30
@@ -236,7 +236,7 @@ def test_whitening(solver, copy):
    assert X.shape == (n_samples, n_features)

    # the component-wise variance is thus highly varying:
-    assert X.std(axis=0).std() > 43.8
+    assert X.std(axis=0).std() > 40
Comment by sply88 (contributor, author):
The previous hardcoded threshold of 43.8 seemed to be tailored to X generated with seed 0. I reduced it a bit so the assertion passes for all seeds in the 0-99 range; the new threshold of 40 still confirms highly varying variances. A sketch of this seed-range check follows below.
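A quick way to sanity-check such a relaxed bound is to scan the same 0-99 seed range the fixture uses and record the worst case. In this sketch, make_whitening_X is a hypothetical stand-in for the dataset construction used by test_whitening, which the diff does not show in full:

```python
# Sketch only: verify a seed-insensitive threshold empirically.
import numpy as np


def min_variance_spread(make_whitening_X, seeds=range(100)):
    # smallest component-wise variance spread over the fixture's seed range
    spreads = [
        make_whitening_X(np.random.RandomState(seed)).std(axis=0).std()
        for seed in seeds
    ]
    return min(spreads)


# a relaxed bound such as 40 should sit below this minimum with some margin
```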


# whiten the data while projecting to the lower dim subspace
X_ = X.copy() # make sure we keep an original across iterations.
@@ -245,8 +245,9 @@ def test_whitening(solver, copy, global_random_seed):
        whiten=True,
        copy=copy,
        svd_solver=solver,
-        random_state=0,
+        random_state=global_random_seed,
        iterated_power=7,
+        n_oversamples=13,
Comment by sply88 (contributor, author):
After adjusting the hardcoded assertion thresholds in this test, a few parametrizations with the randomized solver still failed `assert_allclose(X_whitened, X_whitened2, rtol=5e-4)`. Increasing n_oversamples from the default of 10 to 13 solved this without loosening the assertion tolerance. A small illustration of the n_oversamples effect follows below.
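To see the knob in isolation: sklearn.utils.extmath.randomized_svd exposes the same n_oversamples parameter, and a rough comparison against exact singular values shows how a few extra oversamples tighten the approximation. This is an illustration of the general mechanism, not a reproduction of the failing test:

```python
# Illustrative: effect of n_oversamples on randomized SVD accuracy.
import numpy as np
from sklearn.utils.extmath import randomized_svd

rng = np.random.RandomState(0)
A = rng.randn(100, 80)
exact = np.linalg.svd(A, compute_uv=False)[:30]  # reference singular values

for n_oversamples in (10, 13):
    _, s, _ = randomized_svd(
        A, n_components=30, n_oversamples=n_oversamples, n_iter=7, random_state=0
    )
    # worst relative error of the leading singular values
    print(n_oversamples, np.max(np.abs(s - exact) / exact))
```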

)
# test fit_transform
X_whitened = pca.fit_transform(X_.copy())
@@ -259,13 +260,17 @@

    X_ = X.copy()
    pca = PCA(
-        n_components=n_components, whiten=False, copy=copy, svd_solver=solver
+        n_components=n_components,
+        whiten=False,
+        copy=copy,
+        svd_solver=solver,
+        random_state=global_random_seed,
    ).fit(X_.copy())
    X_unwhitened = pca.transform(X_)
    assert X_unwhitened.shape == (n_samples, n_components)

    # in that case the output components still have varying variances
-    assert X_unwhitened.std(axis=0).std() == pytest.approx(74.1, rel=1e-1)
+    assert X_unwhitened.std(axis=0).std() > 69
Comment by sply88 (contributor, author):
Again, the previous comparison was probably tailored to the seed-0 dataset. I lowered the threshold a bit so it passes for all seeds in the 0-99 range. I also switched from approximate equality to a greater-than check, since the intent is to confirm varying variances.

# we always center, so no test for non-centering.


@@ -410,17 +415,21 @@ def test_pca_solver_equivalence(


@pytest.mark.parametrize(
-    "X",
+    "make_X",
    [
-        np.random.RandomState(0).randn(100, 80),
-        datasets.make_classification(100, 80, n_informative=78, random_state=0)[0],
-        np.random.RandomState(0).randn(10, 100),
+        lambda seed: np.random.RandomState(seed).randn(100, 60),
Comment by sply88 (contributor, author), May 19, 2023:
The last assertion in this test makes sure that explained_variance_ is the same as the n_components largest eigenvalues of the covariance matrix of X.
For two values of global_random_seed (17 and 37) this failed with the random-tall dataset and the randomized solver. Since randomized SVD is an approximation, a little seed-sensitivity is not surprising when comparing against the exact spectrum of the covariance matrix.
Instead of loosening the assertion tolerance, I slightly reduced the number of features for the random-tall setting from 80 to 60.

+        lambda seed: datasets.make_classification(
+            100, 80, n_informative=78, random_state=seed
+        )[0],
+        lambda seed: np.random.RandomState(seed).randn(10, 100),
],
ids=["random-tall", "correlated-tall", "random-wide"],
)
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_explained_variance_empirical(X, svd_solver):
-    pca = PCA(n_components=2, svd_solver=svd_solver, random_state=0)
+def test_pca_explained_variance_empirical(make_X, svd_solver, global_random_seed):
+    # parametrized factory make_X and global_random_seed determine dataset X
+    X = make_X(global_random_seed)
Comment by sply88 (contributor, author):
Instead of directly using X in pytest.mark.parametrize, I use the factory function make_X to inject global_random_seed during dataset creation. This is done because pytest only resolves fixtures inside the test function, not in the decorator (see the sketch below).
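A stripped-down sketch of this pattern, with hypothetical test and parameter names, in case it is useful to reviewers:

```python
# Parametrize over callables so a fixture value can enter dataset construction.
import numpy as np
import pytest


@pytest.mark.parametrize(
    "make_data",
    [lambda seed: np.random.RandomState(seed).randn(20, 5)],
    ids=["random"],
)
def test_data_shape(make_data, global_random_seed):
    # the fixture value is only available here, inside the test body
    X = make_data(global_random_seed)
    assert X.shape == (20, 5)
```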

+    pca = PCA(n_components=2, svd_solver=svd_solver, random_state=global_random_seed)
X_pca = pca.fit_transform(X)
assert_allclose(pca.explained_variance_, np.var(X_pca, ddof=1, axis=0))

@@ -430,8 +439,8 @@ def test_pca_explained_variance_empirical(X, svd_solver):


@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
-def test_pca_singular_values_consistency(svd_solver):
-    rng = np.random.RandomState(0)
+def test_pca_singular_values_consistency(svd_solver, global_random_seed):
+    rng = np.random.RandomState(global_random_seed)
n_samples, n_features = 100, 80
X = rng.randn(n_samples, n_features)

@@ -445,8 +454,8 @@ def test_pca_singular_values_consistency(svd_solver):


@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_singular_values(svd_solver):
-    rng = np.random.RandomState(0)
+def test_pca_singular_values(svd_solver, global_random_seed):
+    rng = np.random.RandomState(global_random_seed)
n_samples, n_features = 100, 80
X = rng.randn(n_samples, n_features)

@@ -475,9 +484,9 @@ def test_pca_singular_values(svd_solver):


@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_check_projection(svd_solver):
+def test_pca_check_projection(svd_solver, global_random_seed):
    # Test that the projection of data is correct
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
n, p = 100, 3
X = rng.randn(n, p) * 0.1
X[:10] += np.array([3, 4, 5])
@@ -490,10 +499,10 @@ def test_pca_check_projection(svd_solver):


@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_check_projection_list(svd_solver):
+def test_pca_check_projection_list(svd_solver, global_random_seed):
    # Test that the projection of data is correct
    X = [[1.0, 0.0], [0.0, 1.0]]
-    pca = PCA(n_components=1, svd_solver=svd_solver, random_state=0)
+    pca = PCA(n_components=1, svd_solver=svd_solver, random_state=global_random_seed)
X_trans = pca.fit_transform(X)
assert X_trans.shape, (2, 1)
assert_allclose(X_trans.mean(), 0.00, atol=1e-12)
@@ -502,20 +511,22 @@ def test_pca_check_projection_list(svd_solver):

@pytest.mark.parametrize("svd_solver", ["full", "arpack", "randomized"])
@pytest.mark.parametrize("whiten", [False, True])
-def test_pca_inverse(svd_solver, whiten):
+def test_pca_inverse(svd_solver, whiten, global_random_seed):
    # Test that the projection of data can be inverted
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
n, p = 50, 3
X = rng.randn(n, p) # spherical data
X[:, 1] *= 0.00001 # make middle component relatively small
X += [5, 4, 3] # make a large mean

# same check that we can find the original data from the transformed
# signal (since the data is almost of rank n_components)
-    pca = PCA(n_components=2, svd_solver=svd_solver, whiten=whiten).fit(X)
+    pca = PCA(
+        n_components=2, svd_solver=svd_solver, whiten=whiten, random_state=rng
+    ).fit(X)
Y = pca.transform(X)
Y_inverse = pca.inverse_transform(Y)
-    assert_allclose(X, Y_inverse, rtol=5e-6)
+    assert_allclose(X, Y_inverse, rtol=5e-5)
Comment by sply88 (contributor, author):
The strict rtol=5e-6 made the test highly seed-sensitive: before changing the tolerance, 486 of 600 parametrizations failed once the global_random_seed fixture was included.
In my opinion, rtol=5e-5 still asserts good recovery of X; we cannot expect perfect recovery, even though the middle column of X has very low variance.



@pytest.mark.parametrize(
@@ -575,9 +586,9 @@ def test_n_components_none(data, solver, n_components_):


@pytest.mark.parametrize("svd_solver", ["auto", "full"])
-def test_n_components_mle(svd_solver):
+def test_n_components_mle(svd_solver, global_random_seed):
    # Ensure that n_components == 'mle' doesn't raise error for auto/full
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
n_samples, n_features = 600, 10
X = rng.randn(n_samples, n_features)
pca = PCA(n_components="mle", svd_solver=svd_solver)
@@ -600,22 +611,22 @@ def test_n_components_mle_error(svd_solver):
pca.fit(X)


-def test_pca_dim():
+def test_pca_dim(global_random_seed):
    # Check automated dimensionality setting
-    rng = np.random.RandomState(0)
-    n, p = 100, 5
+    rng = np.random.RandomState(global_random_seed)
+    n, p = 250, 5
Comment by sply88 (contributor, author):
Merely adding the global_random_seed fixture showed that this test was seed-sensitive. Since the MLE cannot be expected to yield the "true" number of components for every dataset, this is not surprising.
Increasing the number of samples makes the test seed-insensitive in the 0-99 range (a sketch of that check follows below).
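For reference, the kind of scan that justifies the new n=250. The data construction mirrors the test body shown in the diff below; the helper name is hypothetical:

```python
# Sketch: check that n_components='mle' recovers 1 component for all seeds.
import numpy as np
from sklearn.decomposition import PCA


def mle_recovers_one_component(n=250, p=5, seeds=range(100)):
    for seed in seeds:
        rng = np.random.RandomState(seed)
        X = rng.randn(n, p) * 0.1
        X[:10] += np.array([3, 4, 5, 1, 2])  # one strong direction
        if PCA(n_components="mle", svd_solver="full").fit(X).n_components_ != 1:
            return False
    return True
```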

X = rng.randn(n, p) * 0.1
X[:10] += np.array([3, 4, 5, 1, 2])
pca = PCA(n_components="mle", svd_solver="full").fit(X)
assert pca.n_components == "mle"
assert pca.n_components_ == 1


-def test_infer_dim_1():
+def test_infer_dim_1(global_random_seed):
# TODO: explain what this is testing
# Or at least use explicit variable names...
n, p = 1000, 5
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
X = (
rng.randn(n, p) * 0.1
+ rng.randn(n, 1) * np.array([3, 4, 5, 1, 2])
@@ -628,11 +639,11 @@ def test_infer_dim_1():
assert ll[1] > ll.max() - 0.01 * n


-def test_infer_dim_2():
+def test_infer_dim_2(global_random_seed):
    # TODO: explain what this is testing
    # Or at least use explicit variable names...
    n, p = 1000, 5
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
X = rng.randn(n, p) * 0.1
X[:10] += np.array([3, 4, 5, 1, 2])
X[10:20] += np.array([6, 0, 7, 2, -1])
@@ -642,9 +653,9 @@ def test_infer_dim_2():
assert _infer_dimension(spect, n) > 1


-def test_infer_dim_3():
+def test_infer_dim_3(global_random_seed):
    n, p = 100, 5
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
X = rng.randn(n, p) * 0.1
X[:10] += np.array([3, 4, 5, 1, 2])
X[10:20] += np.array([6, 0, 7, 2, -1])
@@ -671,12 +682,12 @@ def test_infer_dim_by_explained_variance(X, n_components, n_components_validated


@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_score(svd_solver):
+def test_pca_score(svd_solver, global_random_seed):
    # Test that probabilistic PCA scoring yields a reasonable score
    n, p = 1000, 3
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
    X = rng.randn(n, p) * 0.1 + np.array([3, 4, 5])
-    pca = PCA(n_components=2, svd_solver=svd_solver)
+    pca = PCA(n_components=2, svd_solver=svd_solver, random_state=rng)
pca.fit(X)

ll1 = pca.score(X)
@@ -686,45 +697,58 @@ def test_pca_score(svd_solver):
ll2 = pca.score(rng.randn(n, p) * 0.2 + np.array([3, 4, 5]))
assert ll1 > ll2

-    pca = PCA(n_components=2, whiten=True, svd_solver=svd_solver)
+    pca = PCA(n_components=2, whiten=True, svd_solver=svd_solver, random_state=rng)
pca.fit(X)
ll2 = pca.score(X)
assert ll1 > ll2


-def test_pca_score3():
+def test_pca_score3(global_random_seed):
Comment by sply88 (contributor, author):
Just including the global_random_seed fixture made the test highly seed-sensitive: for 24 out of 100 seeds, the test score was highest for the ("wrong") model with two components.
I believe the test is supposed to address the following point: a possible strategy for selecting an appropriate number of components is to compare the likelihood of held-out data, e.g. in a cross-validation scheme as illustrated in this example (see the sketch below). However, we cannot expect the test likelihood of the "true" model to be maximal for every (isolated) pair of train and test datasets, only on average. So I changed the test to run 10 trials with different train-test pairs for each n_components and accumulate the test scores.
To keep the test fast, I also reduced the sample size: the original sample size of 200 required 15 trials to make the test seed-insensitive in the 0-99 range, whereas with 50 samples only 10 trials are needed.
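The selection strategy the comment alludes to, sketched with cross-validation on hypothetical data. cross_val_score works here because PCA.score returns the average log-likelihood of the held-out fold:

```python
# Sketch: choose n_components by cross-validated probabilistic PCA score.
import numpy as np
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_score

rng = np.random.RandomState(0)
# rank-1 signal in 3 dimensions plus isotropic noise
X = rng.randn(200, 3) + rng.randn(200, 1) * np.array([3, 4, 5])

scores = [
    cross_val_score(PCA(n_components=k, svd_solver="full"), X, cv=5).mean()
    for k in range(3)
]
print(int(np.argmax(scores)))  # expected: 1, matching the generating model
```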

# Check that probabilistic PCA selects the right model
-    n, p = 200, 3
-    rng = np.random.RandomState(0)
-    Xl = rng.randn(n, p) + rng.randn(n, 1) * np.array([3, 4, 5]) + np.array([1, 0, 7])
-    Xt = rng.randn(n, p) + rng.randn(n, 1) * np.array([3, 4, 5]) + np.array([1, 0, 7])
+    n, p = 50, 3
+    rng = np.random.RandomState(global_random_seed)
    ll = np.zeros(p)
-    for k in range(p):
-        pca = PCA(n_components=k, svd_solver="full")
-        pca.fit(Xl)
-        ll[k] = pca.score(Xt)
+    for _ in range(10):
+        X_train = (
+            rng.randn(n, p)
+            + rng.randn(n, 1) * np.array([3, 4, 5])
+            + np.array([1, 0, 7])
+        )
+        X_test = (
+            rng.randn(n, p)
+            + rng.randn(n, 1) * np.array([3, 4, 5])
+            + np.array([1, 0, 7])
+        )
+        for k in range(p):
+            pca = PCA(n_components=k, svd_solver="full")
+            pca.fit(X_train)
+            ll[k] += pca.score(X_test)

+    # Across multiple trials, the true data generating model should have
+    # accumulated the highest test score
    assert ll.argmax() == 1


@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_sanity_noise_variance(svd_solver):
+def test_pca_sanity_noise_variance(svd_solver, global_random_seed):
# Sanity check for the noise_variance_. For more details see
# https://github.com/scikit-learn/scikit-learn/issues/7568
# https://github.com/scikit-learn/scikit-learn/issues/8541
# https://github.com/scikit-learn/scikit-learn/issues/8544
X, _ = datasets.load_digits(return_X_y=True)
-    pca = PCA(n_components=30, svd_solver=svd_solver, random_state=0)
+    pca = PCA(n_components=30, svd_solver=svd_solver, random_state=global_random_seed)
pca.fit(X)
assert np.all((pca.explained_variance_ - pca.noise_variance_) >= 0)


@pytest.mark.parametrize("svd_solver", ["arpack", "randomized"])
-def test_pca_score_consistency_solvers(svd_solver):
+def test_pca_score_consistency_solvers(svd_solver, global_random_seed):
# Check the consistency of score between solvers
X, _ = datasets.load_digits(return_X_y=True)
-    pca_full = PCA(n_components=30, svd_solver="full", random_state=0)
-    pca_other = PCA(n_components=30, svd_solver=svd_solver, random_state=0)
+    pca_full = PCA(n_components=10, svd_solver="full", random_state=global_random_seed)
Comment by sply88 (contributor, author):
The test was seed-sensitive for the given assertion tolerance when svd_solver="randomized" was used. Again, I think this can be expected when comparing the randomized solution to the exact solution.
Only comparing the first 10 components (instead of the first 30 components) mitigates this somewhat and the test becomes seed-insensitive in the 0-99 range.

+    pca_other = PCA(
+        n_components=10, svd_solver=svd_solver, random_state=global_random_seed
+    )
pca_full.fit(X)
pca_other.fit(X)
assert_allclose(pca_full.score(X), pca_other.score(X), rtol=5e-6)
@@ -781,8 +805,8 @@ def test_pca_svd_solver_auto(n_samples, n_features, n_components, expected_solve


@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-def test_pca_deterministic_output(svd_solver):
-    rng = np.random.RandomState(0)
+def test_pca_deterministic_output(svd_solver, global_random_seed):
+    rng = np.random.RandomState(global_random_seed)
X = rng.rand(10, 10)

transformed_X = np.zeros((20, 2))
@@ -873,7 +897,7 @@ def test_small_eigenvalues_mle():
assert _infer_dimension(spectrum, 10) == 1


-def test_mle_redundant_data():
+def test_mle_redundant_data(global_random_seed):
# Test 'mle' with pathological X: only one relevant feature should give a
# rank of 1
X, _ = datasets.make_classification(
Expand All @@ -882,7 +906,7 @@ def test_mle_redundant_data():
n_repeated=18,
n_redundant=1,
n_clusters_per_class=1,
-        random_state=42,
+        random_state=global_random_seed,
)
pca = PCA(n_components="mle").fit(X)
assert pca.n_components_ == 1
@@ -901,11 +925,11 @@ def test_fit_mle_too_few_samples():
pca.fit(X)


-def test_mle_simple_case():
+def test_mle_simple_case(global_random_seed):
    # non-regression test for issue
    # https://github.com/scikit-learn/scikit-learn/issues/16730
    n_samples, n_dim = 1000, 10
-    X = np.random.RandomState(0).randn(n_samples, n_dim)
+    X = np.random.RandomState(global_random_seed).randn(n_samples, n_dim)
X[:, -1] = np.mean(X[:, :-1], axis=-1) # true X dim is ndim - 1
pca_skl = PCA("mle", svd_solver="full")
pca_skl.fit(X)
@@ -925,14 +949,14 @@ def test_assess_dimesion_rank_one():
assert _assess_dimension(s, rank, n_samples) == -np.inf


-def test_pca_randomized_svd_n_oversamples():
+def test_pca_randomized_svd_n_oversamples(global_random_seed):
"""Check that exposing and setting `n_oversamples` will provide accurate results
even when `X` as a large number of features.

Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/20589
"""
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
n_features = 100
X = rng.randn(1_000, n_features)

@@ -945,7 +969,9 @@ def test_pca_randomized_svd_n_oversamples():
random_state=0,
).fit(X)
pca_full = PCA(n_components=1, svd_solver="full").fit(X)
-    pca_arpack = PCA(n_components=1, svd_solver="arpack", random_state=0).fit(X)
+    pca_arpack = PCA(
+        n_components=1, svd_solver="arpack", random_state=global_random_seed
+    ).fit(X)

assert_allclose(np.abs(pca_full.components_), np.abs(pca_arpack.components_))
assert_allclose(np.abs(pca_randomized.components_), np.abs(pca_arpack.components_))
@@ -960,9 +986,9 @@ def test_feature_names_out():


@pytest.mark.parametrize("copy", [True, False])
-def test_variance_correctness(copy):
+def test_variance_correctness(copy, global_random_seed):
    """Check the accuracy of PCA's internal variance calculation"""
-    rng = np.random.RandomState(0)
+    rng = np.random.RandomState(global_random_seed)
X = rng.randn(1000, 200)
pca = PCA().fit(X)
pca_var = pca.explained_variance_ / pca.explained_variance_ratio_