Skip to content

Commit f332367

Browse files
committed
Merge pull request #249 from VirgileFritsch/master
BF: Fix bug in multiple_fast_inv Some computations were done outside the loop, but they depend on values changing at each iteration. This resulted in wrong inverse values because of the use of the wrong piv variable. Also the check of info was done only for the last decomposition.
2 parents 5ce0c85 + cb53c6f commit f332367

File tree

2 files changed

+31
-21
lines changed

2 files changed

+31
-21
lines changed

nipy/algorithms/statistics/tests/test_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy.stats import norm
55

6-
from ..utils import multiple_mahalanobis, z_score
6+
from ..utils import multiple_mahalanobis, z_score, multiple_fast_inv
77
from nose.tools import assert_true
88
from numpy.testing import assert_almost_equal, assert_array_almost_equal
99

@@ -31,6 +31,16 @@ def test_mahalanobis2():
3131
f_mah = (multiple_mahalanobis(x, Aa))[i]
3232
assert_true(np.allclose(mah, f_mah))
3333

34+
def test_multiple_fast_inv():
35+
shape = (10, 20, 20)
36+
X = np.random.randn(shape[0], shape[1], shape[2])
37+
X_inv_ref = np.zeros(shape)
38+
for i in range(shape[0]):
39+
X[i] = np.dot(X[i], X[i].T)
40+
X_inv_ref[i] = np.linalg.inv(X[i])
41+
X_inv = multiple_fast_inv(X)
42+
assert_almost_equal(X_inv_ref, X_inv)
43+
3444

3545
if __name__ == "__main__":
3646
import nose

nipy/algorithms/statistics/utils.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,32 @@ def multiple_fast_inv(a):
4343
from scipy.linalg import calc_lwork
4444
from scipy.linalg.lapack import get_lapack_funcs
4545
a1, n = a[0], a.shape[0]
46-
getrf, getri = get_lapack_funcs(('getrf','getri'), (a1,))
47-
if getrf.module_name[:7] == 'clapack' != getri.module_name[:7]:
48-
# ATLAS 3.2.1 has getrf but not getri.
49-
for i in range(n):
46+
getrf, getri = get_lapack_funcs(('getrf', 'getri'), (a1,))
47+
for i in range(n):
48+
if (getrf.module_name[:7] == 'clapack'
49+
and getri.module_name[:7] != 'clapack'):
50+
# ATLAS 3.2.1 has getrf but not getri.
5051
lu, piv, info = getrf(np.transpose(a[i]), rowmajor=0,
5152
overwrite_a=True)
5253
a[i] = np.transpose(lu)
53-
else:
54-
for i in range(n):
54+
else:
5555
a[i], piv, info = getrf(a[i], overwrite_a=True)
56-
if info == 0:
57-
if getri.module_name[:7] == 'flapack':
58-
lwork = calc_lwork.getri(getri.prefix, a1.shape[0])
59-
lwork = lwork[1]
60-
# XXX: the following line fixes curious SEGFAULT when
61-
# benchmarking 500x500 matrix inverse. This seems to
62-
# be a bug in LAPACK ?getri routine because if lwork is
63-
# minimal (when using lwork[0] instead of lwork[1]) then
64-
# all tests pass. Further investigation is required if
65-
# more such SEGFAULTs occur.
66-
lwork = int(1.01 * lwork)
67-
for i in range(n):
56+
if info == 0:
57+
if getri.module_name[:7] == 'flapack':
58+
lwork = calc_lwork.getri(getri.prefix, a1.shape[0])
59+
lwork = lwork[1]
60+
# XXX: the following line fixes curious SEGFAULT when
61+
# benchmarking 500x500 matrix inverse. This seems to
62+
# be a bug in LAPACK ?getri routine because if lwork is
63+
# minimal (when using lwork[0] instead of lwork[1]) then
64+
# all tests pass. Further investigation is required if
65+
# more such SEGFAULTs occur.
66+
lwork = int(1.01 * lwork)
6867
a[i], _ = getri(a[i], piv, lwork=lwork, overwrite_lu=1)
69-
else: # clapack
70-
for i in range(n):
68+
else: # clapack
7169
a[i], _ = getri(a[i], piv, overwrite_lu=1)
70+
else:
71+
raise ValueError('Matrix LU decomposition failed')
7272
return a
7373

7474

0 commit comments

Comments
 (0)