Skip to content

Commit 62a0bcd

Browse files
ogriselNelleV
authored andcommitted
Better fix the rng seed in test_fastica_simple (#13848)
1 parent c28ef9e commit 62a0bcd

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

sklearn/decomposition/tests/test_fastica.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import itertools
55
import warnings
6+
import pytest
67

78
import numpy as np
89
from scipy import stats
@@ -50,9 +51,11 @@ def test_gs():
5051
assert_less((tmp[:5] ** 2).sum(), 1.e-10)
5152

5253

53-
def test_fastica_simple(add_noise=False):
54+
@pytest.mark.parametrize("add_noise", [True, False])
55+
@pytest.mark.parametrize("seed", range(1))
56+
def test_fastica_simple(add_noise, seed):
5457
# Test the FastICA algorithm on very simple data.
55-
rng = np.random.RandomState(0)
58+
rng = np.random.RandomState(seed)
5659
# scipy.stats uses the global RNG:
5760
n_samples = 1000
5861
# Generate two sources:
@@ -82,12 +85,15 @@ def g_test(x):
8285
whitening = [True, False]
8386
for algo, nl, whiten in itertools.product(algos, nls, whitening):
8487
if whiten:
85-
k_, mixing_, s_ = fastica(m.T, fun=nl, algorithm=algo)
88+
k_, mixing_, s_ = fastica(m.T, fun=nl, algorithm=algo,
89+
random_state=rng)
8690
assert_raises(ValueError, fastica, m.T, fun=np.tanh,
8791
algorithm=algo)
8892
else:
89-
X = PCA(n_components=2, whiten=True).fit_transform(m.T)
90-
k_, mixing_, s_ = fastica(X, fun=nl, algorithm=algo, whiten=False)
93+
pca = PCA(n_components=2, whiten=True, random_state=rng)
94+
X = pca.fit_transform(m.T)
95+
k_, mixing_, s_ = fastica(X, fun=nl, algorithm=algo, whiten=False,
96+
random_state=rng)
9197
assert_raises(ValueError, fastica, X, fun=np.tanh,
9298
algorithm=algo)
9399
s_ = s_.T
@@ -113,8 +119,9 @@ def g_test(x):
113119
assert_almost_equal(np.dot(s2_, s2) / n_samples, 1, decimal=1)
114120

115121
# Test FastICA class
116-
_, _, sources_fun = fastica(m.T, fun=nl, algorithm=algo, random_state=0)
117-
ica = FastICA(fun=nl, algorithm=algo, random_state=0)
122+
_, _, sources_fun = fastica(m.T, fun=nl, algorithm=algo,
123+
random_state=seed)
124+
ica = FastICA(fun=nl, algorithm=algo, random_state=seed)
118125
sources = ica.fit_transform(m.T)
119126
assert_equal(ica.components_.shape, (2, 2))
120127
assert_equal(sources.shape, (1000, 2))
@@ -125,7 +132,7 @@ def g_test(x):
125132
assert_equal(ica.mixing_.shape, (2, 2))
126133

127134
for fn in [np.tanh, "exp(-.5(x^2))"]:
128-
ica = FastICA(fun=fn, algorithm=algo, random_state=0)
135+
ica = FastICA(fun=fn, algorithm=algo)
129136
assert_raises(ValueError, ica.fit, m.T)
130137

131138
assert_raises(TypeError, FastICA(fun=range(10)).fit, m.T)

0 commit comments

Comments
 (0)