Skip to content

Commit d915f89

Browse files
[BF] Fix bug in multiple_fast_inv.
Some computations were done outside the loop, but they depend on values changing at each iteration.
1 parent a248951 commit d915f89

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

nipy/algorithms/statistics/utils.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,31 @@ def multiple_fast_inv(a):
4343
from scipy.linalg import calc_lwork
4444
from scipy.linalg.lapack import get_lapack_funcs
4545
a1, n = a[0], a.shape[0]
46-
getrf, getri = get_lapack_funcs(('getrf','getri'), (a1,))
47-
if getrf.module_name[:7] == 'clapack' != getri.module_name[:7]:
48-
# ATLAS 3.2.1 has getrf but not getri.
49-
for i in range(n):
46+
getrf, getri = get_lapack_funcs(('getrf', 'getri'), (a1,))
47+
for i in range(n):
48+
if getrf.module_name[:7] == 'clapack' != getri.module_name[:7]:
49+
# ATLAS 3.2.1 has getrf but not getri.
5050
lu, piv, info = getrf(np.transpose(a[i]), rowmajor=0,
5151
overwrite_a=True)
5252
a[i] = np.transpose(lu)
53-
else:
54-
for i in range(n):
53+
else:
5554
a[i], piv, info = getrf(a[i], overwrite_a=True)
56-
if info == 0:
57-
if getri.module_name[:7] == 'flapack':
58-
lwork = calc_lwork.getri(getri.prefix, a1.shape[0])
59-
lwork = lwork[1]
60-
# XXX: the following line fixes curious SEGFAULT when
61-
# benchmarking 500x500 matrix inverse. This seems to
62-
# be a bug in LAPACK ?getri routine because if lwork is
63-
# minimal (when using lwork[0] instead of lwork[1]) then
64-
# all tests pass. Further investigation is required if
65-
# more such SEGFAULTs occur.
66-
lwork = int(1.01 * lwork)
67-
for i in range(n):
55+
if info == 0:
56+
if getri.module_name[:7] == 'flapack':
57+
lwork = calc_lwork.getri(getri.prefix, a1.shape[0])
58+
lwork = lwork[1]
59+
# XXX: the following line fixes curious SEGFAULT when
60+
# benchmarking 500x500 matrix inverse. This seems to
61+
# be a bug in LAPACK ?getri routine because if lwork is
62+
# minimal (when using lwork[0] instead of lwork[1]) then
63+
# all tests pass. Further investigation is required if
64+
# more such SEGFAULTs occur.
65+
lwork = int(1.01 * lwork)
6866
a[i], _ = getri(a[i], piv, lwork=lwork, overwrite_lu=1)
69-
else: # clapack
70-
for i in range(n):
67+
else: # clapack
7168
a[i], _ = getri(a[i], piv, overwrite_lu=1)
69+
else:
70+
raise ValueError('Matrix LU decomposition failed')
7271
return a
7372

7473

0 commit comments

Comments
 (0)