3
3
"""
4
4
import itertools
5
5
import warnings
6
+ import pytest
6
7
7
8
import numpy as np
8
9
from scipy import stats
@@ -50,9 +51,11 @@ def test_gs():
50
51
assert_less ((tmp [:5 ] ** 2 ).sum (), 1.e-10 )
51
52
52
53
53
- def test_fastica_simple (add_noise = False ):
54
+ @pytest .mark .parametrize ("add_noise" , [True , False ])
55
+ @pytest .mark .parametrize ("seed" , range (1 ))
56
+ def test_fastica_simple (add_noise , seed ):
54
57
# Test the FastICA algorithm on very simple data.
55
- rng = np .random .RandomState (0 )
58
+ rng = np .random .RandomState (seed )
56
59
# scipy.stats uses the global RNG:
57
60
n_samples = 1000
58
61
# Generate two sources:
@@ -82,12 +85,15 @@ def g_test(x):
82
85
whitening = [True , False ]
83
86
for algo , nl , whiten in itertools .product (algos , nls , whitening ):
84
87
if whiten :
85
- k_ , mixing_ , s_ = fastica (m .T , fun = nl , algorithm = algo )
88
+ k_ , mixing_ , s_ = fastica (m .T , fun = nl , algorithm = algo ,
89
+ random_state = rng )
86
90
assert_raises (ValueError , fastica , m .T , fun = np .tanh ,
87
91
algorithm = algo )
88
92
else :
89
- X = PCA (n_components = 2 , whiten = True ).fit_transform (m .T )
90
- k_ , mixing_ , s_ = fastica (X , fun = nl , algorithm = algo , whiten = False )
93
+ pca = PCA (n_components = 2 , whiten = True , random_state = rng )
94
+ X = pca .fit_transform (m .T )
95
+ k_ , mixing_ , s_ = fastica (X , fun = nl , algorithm = algo , whiten = False ,
96
+ random_state = rng )
91
97
assert_raises (ValueError , fastica , X , fun = np .tanh ,
92
98
algorithm = algo )
93
99
s_ = s_ .T
@@ -113,8 +119,9 @@ def g_test(x):
113
119
assert_almost_equal (np .dot (s2_ , s2 ) / n_samples , 1 , decimal = 1 )
114
120
115
121
# Test FastICA class
116
- _ , _ , sources_fun = fastica (m .T , fun = nl , algorithm = algo , random_state = 0 )
117
- ica = FastICA (fun = nl , algorithm = algo , random_state = 0 )
122
+ _ , _ , sources_fun = fastica (m .T , fun = nl , algorithm = algo ,
123
+ random_state = seed )
124
+ ica = FastICA (fun = nl , algorithm = algo , random_state = seed )
118
125
sources = ica .fit_transform (m .T )
119
126
assert_equal (ica .components_ .shape , (2 , 2 ))
120
127
assert_equal (sources .shape , (1000 , 2 ))
@@ -125,7 +132,7 @@ def g_test(x):
125
132
assert_equal (ica .mixing_ .shape , (2 , 2 ))
126
133
127
134
for fn in [np .tanh , "exp(-.5(x^2))" ]:
128
- ica = FastICA (fun = fn , algorithm = algo , random_state = 0 )
135
+ ica = FastICA (fun = fn , algorithm = algo )
129
136
assert_raises (ValueError , ica .fit , m .T )
130
137
131
138
assert_raises (TypeError , FastICA (fun = range (10 )).fit , m .T )
0 commit comments