Skip to content

Commit 6d66d0f

Browse files
veneamueller
authored andcommitted
Cloned @jakevdp's pinvh tests
1 parent 12f4eb9 commit 6d66d0f

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

sklearn/utils/tests/test_utils.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,27 @@ def test_safe_mask():
9393
assert_equal(X_csr[mask].shape[0], 3)
9494

9595

96-
def test_pinvh():
97-
a = np.random.randn(5, 3)
98-
a = np.dot(a, a.T) # symmetric singular matrix
99-
assert_almost_equal(pinv2(a), pinvh(a))
96+
def test_pinvh_simple_real():
97+
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=np.float64)
98+
a = np.dot(a, a.T)
99+
a_pinv = pinvh(a)
100+
assert_almost_equal(np.dot(a, a_pinv), np.eye(3))
101+
102+
103+
def test_pinvh_nonpositive():
104+
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
105+
a = np.dot(a, a.T)
106+
u, s, vt = np.linalg.svd(a)
107+
s[0] *= -1
108+
a = np.dot(u * s, vt) # a is now symmetric non-positive and singular
109+
a_pinv = pinv2(a)
110+
a_pinvh = pinvh(a)
111+
assert_almost_equal(a_pinv, a_pinvh)
112+
113+
114+
def test_pinvh_simple_complex():
115+
a = (np.array([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
116+
+ 1j * np.array([[10, 8, 7], [6, 5, 4], [3, 2, 1]]))
117+
a = np.dot(a, a.conj().T)
118+
a_pinv = pinvh(a)
119+
assert_almost_equal(np.dot(a, a_pinv), np.eye(3))

0 commit comments

Comments
 (0)