@@ -43,32 +43,31 @@ def multiple_fast_inv(a):
43
43
from scipy .linalg import calc_lwork
44
44
from scipy .linalg .lapack import get_lapack_funcs
45
45
a1 , n = a [0 ], a .shape [0 ]
46
- getrf , getri = get_lapack_funcs (('getrf' ,'getri' ), (a1 ,))
47
- if getrf . module_name [: 7 ] == 'clapack' != getri . module_name [: 7 ] :
48
- # ATLAS 3.2.1 has getrf but not getri.
49
- for i in range ( n ):
46
+ getrf , getri = get_lapack_funcs (('getrf' , 'getri' ), (a1 ,))
47
+ for i in range ( n ) :
48
+ if getrf . module_name [: 7 ] == 'clapack' != getri .module_name [: 7 ]:
49
+ # ATLAS 3.2.1 has getrf but not getri.
50
50
lu , piv , info = getrf (np .transpose (a [i ]), rowmajor = 0 ,
51
51
overwrite_a = True )
52
52
a [i ] = np .transpose (lu )
53
- else :
54
- for i in range (n ):
53
+ else :
55
54
a [i ], piv , info = getrf (a [i ], overwrite_a = True )
56
- if info == 0 :
57
- if getri .module_name [:7 ] == 'flapack' :
58
- lwork = calc_lwork .getri (getri .prefix , a1 .shape [0 ])
59
- lwork = lwork [1 ]
60
- # XXX: the following line fixes curious SEGFAULT when
61
- # benchmarking 500x500 matrix inverse. This seems to
62
- # be a bug in LAPACK ?getri routine because if lwork is
63
- # minimal (when using lwork[0] instead of lwork[1]) then
64
- # all tests pass. Further investigation is required if
65
- # more such SEGFAULTs occur.
66
- lwork = int (1.01 * lwork )
67
- for i in range (n ):
55
+ if info == 0 :
56
+ if getri .module_name [:7 ] == 'flapack' :
57
+ lwork = calc_lwork .getri (getri .prefix , a1 .shape [0 ])
58
+ lwork = lwork [1 ]
59
+ # XXX: the following line fixes curious SEGFAULT when
60
+ # benchmarking 500x500 matrix inverse. This seems to
61
+ # be a bug in LAPACK ?getri routine because if lwork is
62
+ # minimal (when using lwork[0] instead of lwork[1]) then
63
+ # all tests pass. Further investigation is required if
64
+ # more such SEGFAULTs occur.
65
+ lwork = int (1.01 * lwork )
68
66
a [i ], _ = getri (a [i ], piv , lwork = lwork , overwrite_lu = 1 )
69
- else : # clapack
70
- for i in range (n ):
67
+ else : # clapack
71
68
a [i ], _ = getri (a [i ], piv , overwrite_lu = 1 )
69
+ else :
70
+ raise ValueError ('Matrix LU decomposition failed' )
72
71
return a
73
72
74
73
0 commit comments