Skip to content

Commit f7be36b

Browse files
authored
Merge pull request #8827 from eric-wieser/fix-pinv
BUG: Fix pinv for stacked matrices
2 parents 03f3789 + ebe2cfb commit f7be36b

File tree

3 files changed

+71
-35
lines changed

3 files changed

+71
-35
lines changed

doc/release/1.14.0-notes.rst

+4
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ selected via the ``--fcompiler`` and ``--compiler`` options to
217217
supported; by default a gfortran-compatible static archive
218218
``openblas.a`` is looked for.
219219

220+
``np.linalg.pinv`` now works on stacked matrices
221+
------------------------------------------------
222+
Previously it was limited to a single 2d array.
223+
220224
``numpy.save`` aligns data to 64 bytes instead of 16
221225
----------------------------------------------------
222226
Saving NumPy arrays in the ``npy`` format with ``numpy.save`` inserts

numpy/linalg/linalg.py

+61-33
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
import warnings
2020

2121
from numpy.core import (
22-
array, asarray, zeros, empty, empty_like, transpose, intc, single, double,
22+
array, asarray, zeros, empty, empty_like, intc, single, double,
2323
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
2424
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
2525
finfo, errstate, geterrobj, longdouble, moveaxis, amin, amax, product, abs,
26-
broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones
27-
)
26+
broadcast, atleast_2d, intp, asanyarray, isscalar, object_, ones, matmul,
27+
swapaxes, divide)
28+
2829
from numpy.core.multiarray import normalize_axis_index
2930
from numpy.lib import triu, asfarray
3031
from numpy.linalg import lapack_lite, _umath_linalg
@@ -223,6 +224,22 @@ def _assertNoEmpty2d(*arrays):
223224
if _isEmpty2d(a):
224225
raise LinAlgError("Arrays cannot be empty")
225226

227+
def transpose(a):
228+
"""
229+
Transpose each matrix in a stack of matrices.
230+
231+
Unlike np.transpose, this only swaps the last two axes, rather than all of
232+
them
233+
234+
Parameters
235+
----------
236+
a : (...,M,N) array_like
237+
238+
Returns
239+
-------
240+
aT : (...,N,M) ndarray
241+
"""
242+
return swapaxes(a, -1, -2)
226243

227244
# Linear equations
228245

@@ -1279,7 +1296,7 @@ def eigh(a, UPLO='L'):
12791296

12801297
# Singular value decomposition
12811298

1282-
def svd(a, full_matrices=1, compute_uv=1):
1299+
def svd(a, full_matrices=True, compute_uv=True):
12831300
"""
12841301
Singular Value Decomposition.
12851302
@@ -1494,15 +1511,21 @@ def matrix_rank(M, tol=None):
14941511
Rank of the array is the number of SVD singular values of the array that are
14951512
greater than `tol`.
14961513
1514+
.. versionchanged:: 1.14
1515+
Can now operate on stacks of matrices
1516+
14971517
Parameters
14981518
----------
14991519
M : {(M,), (..., M, N)} array_like
15001520
input vector or stack of matrices
1501-
tol : {None, float}, optional
1502-
threshold below which SVD values are considered zero. If `tol` is
1503-
None, and ``S`` is an array with singular values for `M`, and
1504-
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
1505-
set to ``S.max() * max(M.shape) * eps``.
1521+
tol : (...) array_like, float, optional
1522+
threshold below which SVD values are considered zero. If `tol` is
1523+
None, and ``S`` is an array with singular values for `M`, and
1524+
``eps`` is the epsilon value for datatype of ``S``, then `tol` is
1525+
set to ``S.max() * max(M.shape) * eps``.
1526+
1527+
.. versionchanged:: 1.14
1528+
Broadcasted against the stack of matrices
15061529
15071530
Notes
15081531
-----
@@ -1569,6 +1592,8 @@ def matrix_rank(M, tol=None):
15691592
S = svd(M, compute_uv=False)
15701593
if tol is None:
15711594
tol = S.max(axis=-1, keepdims=True) * max(M.shape[-2:]) * finfo(S.dtype).eps
1595+
else:
1596+
tol = asarray(tol)[...,newaxis]
15721597
return (S > tol).sum(axis=-1)
15731598

15741599

@@ -1582,26 +1607,29 @@ def pinv(a, rcond=1e-15 ):
15821607
singular-value decomposition (SVD) and including all
15831608
*large* singular values.
15841609
1610+
.. versionchanged:: 1.14
1611+
Can now operate on stacks of matrices
1612+
15851613
Parameters
15861614
----------
1587-
a : (M, N) array_like
1588-
Matrix to be pseudo-inverted.
1589-
rcond : float
1590-
Cutoff for small singular values.
1591-
Singular values smaller (in modulus) than
1592-
`rcond` * largest_singular_value (again, in modulus)
1593-
are set to zero.
1615+
a : (..., M, N) array_like
1616+
Matrix or stack of matrices to be pseudo-inverted.
1617+
rcond : (...) array_like of float
1618+
Cutoff for small singular values.
1619+
Singular values smaller (in modulus) than
1620+
`rcond` * largest_singular_value (again, in modulus)
1621+
are set to zero. Broadcasts against the stack of matrices
15941622
15951623
Returns
15961624
-------
1597-
B : (N, M) ndarray
1598-
The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so
1599-
is `B`.
1625+
B : (..., N, M) ndarray
1626+
The pseudo-inverse of `a`. If `a` is a `matrix` instance, then so
1627+
is `B`.
16001628
16011629
Raises
16021630
------
16031631
LinAlgError
1604-
If the SVD computation does not converge.
1632+
If the SVD computation does not converge.
16051633
16061634
Notes
16071635
-----
@@ -1638,20 +1666,20 @@ def pinv(a, rcond=1e-15 ):
16381666
16391667
"""
16401668
a, wrap = _makearray(a)
1669+
rcond = asarray(rcond)
16411670
if _isEmpty2d(a):
16421671
res = empty(a.shape[:-2] + (a.shape[-1], a.shape[-2]), dtype=a.dtype)
16431672
return wrap(res)
16441673
a = a.conjugate()
1645-
u, s, vt = svd(a, 0)
1646-
m = u.shape[0]
1647-
n = vt.shape[1]
1648-
cutoff = rcond*maximum.reduce(s)
1649-
for i in range(min(n, m)):
1650-
if s[i] > cutoff:
1651-
s[i] = 1./s[i]
1652-
else:
1653-
s[i] = 0.
1654-
res = dot(transpose(vt), multiply(s[:, newaxis], transpose(u)))
1674+
u, s, vt = svd(a, full_matrices=False)
1675+
1676+
# discard small singular values
1677+
cutoff = rcond[..., newaxis] * amax(s, axis=-1, keepdims=True)
1678+
large = s > cutoff
1679+
s = divide(1, s, where=large, out=s)
1680+
s[~large] = 0
1681+
1682+
res = matmul(transpose(vt), multiply(s[..., newaxis], transpose(u)))
16551683
return wrap(res)
16561684

16571685
# Determinant
@@ -1987,13 +2015,13 @@ def lstsq(a, b, rcond="warn"):
19872015
resids = array([sum((ravel(bstar)[n:])**2)],
19882016
dtype=result_real_t)
19892017
else:
1990-
x = array(transpose(bstar)[:n,:], dtype=result_t, copy=True)
2018+
x = array(bstar.T[:n,:], dtype=result_t, copy=True)
19912019
if results['rank'] == n and m > n:
19922020
if isComplexType(t):
1993-
resids = sum(abs(transpose(bstar)[n:,:])**2, axis=0).astype(
2021+
resids = sum(abs(bstar.T[n:,:])**2, axis=0).astype(
19942022
result_real_t, copy=False)
19952023
else:
1996-
resids = sum((transpose(bstar)[n:,:])**2, axis=0).astype(
2024+
resids = sum((bstar.T[n:,:])**2, axis=0).astype(
19972025
result_real_t, copy=False)
19982026

19992027
st = s[:min(n, m)].astype(result_real_t, copy=True)

numpy/linalg/tests/test_linalg.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -712,12 +712,16 @@ def test(self):
712712
assert_almost_equal(linalg.cond(A, inf), 3.)
713713

714714

715-
class TestPinv(LinalgSquareTestCase, LinalgNonsquareTestCase):
715+
class TestPinv(LinalgSquareTestCase,
716+
LinalgNonsquareTestCase,
717+
LinalgGeneralizedSquareTestCase,
718+
LinalgGeneralizedNonsquareTestCase):
716719

717720
def do(self, a, b, tags):
718721
a_ginv = linalg.pinv(a)
719722
# `a @ a_ginv == I` does not hold if a is singular
720-
assert_almost_equal(dot(a, a_ginv).dot(a), a, single_decimal=5, double_decimal=11)
723+
dot = dot_generalized
724+
assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
721725
assert_(imply(isinstance(a, matrix), isinstance(a_ginv, matrix)))
722726

723727

0 commit comments

Comments
 (0)