Skip to content

Commit 422ca44

Browse files
authored
ENH: Improve np.linalg.det performance by simplifying checks (numpy#28649)
* ENH: Improve np.linalg.det performance * Update numpy/linalg/_linalg.py * revert change to complex detection * use suggestion * whitespace * add more small array benchmarks * trigger build
1 parent b7b368a commit 422ca44

File tree

2 files changed

+17
-17
lines changed

2 files changed

+17
-17
lines changed

benchmarks/benchmarks/bench_linalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def time_norm(self, typename):
103103
class LinalgSmallArrays(Benchmark):
104104
""" Test overhead of linalg methods for small arrays """
105105
def setup(self):
106+
self.array_3_3 = np.eye(3) + np.arange(9.).reshape((3, 3))
107+
self.array_3 = np.arange(3.)
106108
self.array_5 = np.arange(5.)
107109
self.array_5_5 = np.reshape(np.arange(25.), (5, 5))
108110

@@ -112,6 +114,16 @@ def time_norm_small_array(self):
112114
def time_det_small_array(self):
113115
np.linalg.det(self.array_5_5)
114116

117+
def time_det_3x3(self):
118+
np.linalg.det(self.array_3_3)
119+
120+
def time_solve_3x3(self):
121+
np.linalg.solve(self.array_3_3, self.array_3)
122+
123+
def time_eig_3x3(self):
124+
np.linalg.eig(self.array_3_3)
125+
126+
115127
class Lstsq(Benchmark):
116128
def setup(self):
117129
self.a = get_squares_()['float64']

numpy/linalg/_linalg.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ def _assert_stacked_2d(*arrays):
197197

198198
def _assert_stacked_square(*arrays):
199199
for a in arrays:
200-
m, n = a.shape[-2:]
200+
try:
201+
m, n = a.shape[-2:]
202+
except ValueError:
203+
raise LinAlgError('%d-dimensional array given. Array must be '
204+
'at least two-dimensional' % a.ndim)
201205
if m != n:
202206
raise LinAlgError('Last 2 dimensions of the array must be square')
203207

@@ -392,7 +396,6 @@ def solve(a, b):
392396
393397
"""
394398
a, _ = _makearray(a)
395-
_assert_stacked_2d(a)
396399
_assert_stacked_square(a)
397400
b, wrap = _makearray(b)
398401
t, result_t = _commonType(a, b)
@@ -599,7 +602,6 @@ def inv(a):
599602
600603
"""
601604
a, wrap = _makearray(a)
602-
_assert_stacked_2d(a)
603605
_assert_stacked_square(a)
604606
t, result_t = _commonType(a)
605607

@@ -681,7 +683,6 @@ def matrix_power(a, n):
681683
682684
"""
683685
a = asanyarray(a)
684-
_assert_stacked_2d(a)
685686
_assert_stacked_square(a)
686687

687688
try:
@@ -830,7 +831,6 @@ def cholesky(a, /, *, upper=False):
830831
"""
831832
gufunc = _umath_linalg.cholesky_up if upper else _umath_linalg.cholesky_lo
832833
a, wrap = _makearray(a)
833-
_assert_stacked_2d(a)
834834
_assert_stacked_square(a)
835835
t, result_t = _commonType(a)
836836
signature = 'D->D' if isComplexType(t) else 'd->d'
@@ -1201,7 +1201,6 @@ def eigvals(a):
12011201
12021202
"""
12031203
a, wrap = _makearray(a)
1204-
_assert_stacked_2d(a)
12051204
_assert_stacked_square(a)
12061205
_assert_finite(a)
12071206
t, result_t = _commonType(a)
@@ -1310,7 +1309,6 @@ def eigvalsh(a, UPLO='L'):
13101309
gufunc = _umath_linalg.eigvalsh_up
13111310

13121311
a, wrap = _makearray(a)
1313-
_assert_stacked_2d(a)
13141312
_assert_stacked_square(a)
13151313
t, result_t = _commonType(a)
13161314
signature = 'D->d' if isComplexType(t) else 'd->d'
@@ -1320,11 +1318,6 @@ def eigvalsh(a, UPLO='L'):
13201318
w = gufunc(a, signature=signature)
13211319
return w.astype(_realType(result_t), copy=False)
13221320

1323-
def _convertarray(a):
1324-
t, result_t = _commonType(a)
1325-
a = a.astype(t).T.copy()
1326-
return a, t, result_t
1327-
13281321

13291322
# Eigenvectors
13301323

@@ -1461,7 +1454,6 @@ def eig(a):
14611454
14621455
"""
14631456
a, wrap = _makearray(a)
1464-
_assert_stacked_2d(a)
14651457
_assert_stacked_square(a)
14661458
_assert_finite(a)
14671459
t, result_t = _commonType(a)
@@ -1612,7 +1604,6 @@ def eigh(a, UPLO='L'):
16121604
raise ValueError("UPLO argument must be 'L' or 'U'")
16131605

16141606
a, wrap = _makearray(a)
1615-
_assert_stacked_2d(a)
16161607
_assert_stacked_square(a)
16171608
t, result_t = _commonType(a)
16181609

@@ -1978,7 +1969,6 @@ def cond(x, p=None):
19781969
else:
19791970
# Call inv(x) ignoring errors. The result array will
19801971
# contain nans in the entries where inversion failed.
1981-
_assert_stacked_2d(x)
19821972
_assert_stacked_square(x)
19831973
t, result_t = _commonType(x)
19841974
signature = 'D->D' if isComplexType(t) else 'd->d'
@@ -2318,7 +2308,6 @@ def slogdet(a):
23182308
23192309
"""
23202310
a = asarray(a)
2321-
_assert_stacked_2d(a)
23222311
_assert_stacked_square(a)
23232312
t, result_t = _commonType(a)
23242313
real_t = _realType(result_t)
@@ -2377,7 +2366,6 @@ def det(a):
23772366
23782367
"""
23792368
a = asarray(a)
2380-
_assert_stacked_2d(a)
23812369
_assert_stacked_square(a)
23822370
t, result_t = _commonType(a)
23832371
signature = 'D->D' if isComplexType(t) else 'd->d'

0 commit comments

Comments
 (0)