Skip to content

Commit 342ac0b

Browse files
Yannick Schwartzmatthew-brett
Yannick Schwartz
authored andcommitted
Added test in test_glm.test_high_level_glm_null_contrasts and fixed a glitch in test_glm.generate_fake_fmri_data
1 parent 799bb80 commit 342ac0b

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

nipy/modalities/fmri/tests/test_glm.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,16 @@ def write_fake_fmri_data(shapes, rk=3, affine=np.eye(4)):
3434

3535

3636
def generate_fake_fmri_data(shapes, rk=3, affine=np.eye(4)):
37-
fmri_data, design_matrices= []
37+
fmri_data = []
38+
design_matrices = []
3839
for i, shape in enumerate(shapes):
3940
data = 100 + np.random.randn(*shape)
4041
data[0] -= 10
4142
fmri_data.append(Nifti1Image(data, affine))
4243
design_matrices.append(np.random.randn(shape[3], rk))
43-
mask = Nifti1Image((np.random.rand(*shape[:3]) > .5).astype(np.int8),
44+
mask = Nifti1Image((np.random.rand(*shape[:3]) > .5).astype(np.int8),
4445
affine)
45-
return mask, fmri_data, design_matrices
46+
return mask, fmri_data, design_matrices
4647

4748

4849
def test_high_level_glm_with_paths():
@@ -105,6 +106,22 @@ def test_high_level_glm_contrasts():
105106
z1.get_data(), z2.get_data())).all())
106107

107108

109+
def test_high_level_glm_null_contrasts():
110+
shapes, rk = ((5, 6, 7, 20), (5, 6, 7, 19)), 3
111+
mask, fmri_data, design_matrices = generate_fake_fmri_data(shapes, rk)
112+
113+
multi_session_model = FMRILinearModel(
114+
fmri_data, design_matrices, mask=None)
115+
multi_session_model.fit()
116+
single_session_model = FMRILinearModel(
117+
fmri_data[:1], design_matrices[:1], mask=None)
118+
single_session_model.fit()
119+
z1, = multi_session_model.contrast([np.eye(rk)[:1]] * 2)
120+
z2, = multi_session_model.contrast([np.eye(rk)[:1], np.zeros((1, rk))])
121+
z3, = single_session_model.contrast([np.eye(rk)[:1]])
122+
np.testing.assert_almost_equal(z2.get_data(), z3.get_data())
123+
124+
108125
def ols_glm(n=100, p=80, q=10):
109126
X, Y = np.random.randn(p, q), np.random.randn(p, n)
110127
glm = GeneralLinearModel(X)

0 commit comments

Comments
 (0)