Skip to content

FIX randomized_svd for complex valued input #30737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.utils/30737.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :func:`utils.extmath.randomized_svd` now handles complex-valued inputs. By
:user:`Connor Lane <clane9>`.
18 changes: 18 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,3 +1108,21 @@ def _tolist(array, xp=None):
return array.tolist()
array_np = _convert_to_numpy(array, xp=xp)
return [element.item() for element in array_np]


def _conj_transpose(array, xp=None):
    """Transpose `array`, conjugating it first when it is complex valued.

    Array API compliant counterpart of `array.conj().T`. Input that
    `_iscomplexobj` does not report as complex is returned as the plain
    transpose `array.T`, without calling `xp.conj`.
    """
    xp, _ = get_namespace(array, xp=xp)

    # Non-complex input: the plain transpose is all that is needed.
    if not _iscomplexobj(array, xp=xp):
        return array.T

    conjugated = xp.conj(array)
    return conjugated.T


def _iscomplexobj(array, xp=None):
    """Tell whether `array` has a complex floating dtype.

    Array API compliant counterpart of `np.iscomplexobj()`. Objects that do
    not expose a `dtype` attribute are reported as not complex.
    """
    xp, _ = get_namespace(array, xp=xp)

    # Without a dtype there is nothing to inspect.
    if not hasattr(array, "dtype"):
        return False

    return xp.isdtype(array.dtype, kind="complex floating")
30 changes: 23 additions & 7 deletions sklearn/utils/extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
from scipy import linalg, sparse

from ..utils._param_validation import Interval, StrOptions, validate_params
from ._array_api import _average, _is_numpy_namespace, _nanmean, device, get_namespace
from ._array_api import (
_average,
_conj_transpose,
_is_numpy_namespace,
_iscomplexobj,
_nanmean,
device,
get_namespace,
)
from .sparsefuncs_fast import csr_row_norms
from .validation import check_array, check_random_state

Expand Down Expand Up @@ -333,7 +341,8 @@ def randomized_range_finder(
# singular vectors of A in Q
for _ in range(n_iter):
Q, _ = normalizer(A @ Q)
Q, _ = normalizer(A.T @ Q)
# Conjugate transpose for complex input (normal transpose for real)
Q, _ = normalizer(_conj_transpose(A, xp=xp) @ Q)

# Sample the range of A by linear projection of Q
# Extract an orthonormal basis
Expand Down Expand Up @@ -506,6 +515,7 @@ def randomized_svd(
sparse.SparseEfficiencyWarning,
)

xp, is_array_api_compliant = get_namespace(M)
random_state = check_random_state(random_state)
n_random = n_components + n_oversamples
n_samples, n_features = M.shape
Expand All @@ -519,7 +529,8 @@ def randomized_svd(
transpose = n_samples < n_features
if transpose:
# this implementation is a bit faster with smaller shape[1]
M = M.T
# Conjugate transpose for complex input (normal transpose for real)
M = _conj_transpose(M, xp=xp)

Q = randomized_range_finder(
M,
Expand All @@ -530,10 +541,9 @@ def randomized_svd(
)

# project M to the (k + p) dimensional space using the basis vectors
B = Q.T @ M
B = _conj_transpose(Q, xp=xp) @ M

# compute the SVD on the thin matrix: (k + p) wide
xp, is_array_api_compliant = get_namespace(B)
if is_array_api_compliant:
Uhat, s, Vt = xp.linalg.svd(B, full_matrices=False)
else:
Expand All @@ -546,7 +556,9 @@ def randomized_svd(
del B
U = Q @ Uhat

if flip_sign:
# can't flip sign for complex valued input, since complex svd is unique only up to
# phase shifts.
if flip_sign and not _iscomplexobj(M, xp=xp):
if not transpose:
U, Vt = svd_flip(U, Vt)
else:
Expand All @@ -556,7 +568,11 @@ def randomized_svd(

if transpose:
# transpose back the results according to the input convention
return Vt[:n_components, :].T, s[:n_components], U[:, :n_components].T
return (
_conj_transpose(Vt[:n_components, :], xp=xp),
s[:n_components],
_conj_transpose(U[:, :n_components], xp=xp),
)
else:
return U[:, :n_components], s[:n_components], Vt[:n_components, :]

Expand Down
55 changes: 55 additions & 0 deletions sklearn/utils/tests/test_extmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,61 @@ def test_randomized_svd_lapack_driver(n, m, k, seed):
assert_allclose(vt1, vt2, atol=0, rtol=1e-3)


def test_randomized_svd_complex():
    """Check randomized_svd on complex input against the exact linalg.svd.

    A low rank complex matrix is decomposed with `linalg.svd` (reference) and
    with `randomized_svd` for each stable power iteration normalizer, both as
    a dense array and wrapped in every available CSR container.
    """
    n_samples = 100
    n_features = 500
    rank = 5
    k = 10
    decimal = 7

    # Create low rank complex valued matrix consisting of low rank real and imaginary
    # parts, with the same column space.
    rng = np.random.RandomState(42)
    X = make_low_rank_matrix(
        n_samples=n_samples,
        n_features=n_features,
        effective_rank=rank,
        tail_strength=0.0,
        random_state=rng,
    )
    A = rng.randn(n_features, n_features)
    X = X + 1.0j * (X @ A)

    # compute the singular values of X using the slow exact method
    U, s, Vt = linalg.svd(X, full_matrices=False)

    for normalizer in ["auto", "LU", "QR"]:  # 'none' would not be stable
        # compute the singular values of X using the fast approximate method
        Ua, sa, Va = randomized_svd(
            X, k, power_iteration_normalizer=normalizer, random_state=0
        )

        assert Ua.shape == (n_samples, k)
        assert sa.shape == (k,)
        assert Va.shape == (k, n_features)

        # ensure that the singular values of both methods are equal up to the
        # real rank of the matrix
        assert_almost_equal(s[:k], sa, decimal=decimal)

        # check the singular vectors too (while not checking the sign)
        assert_almost_equal(
            np.dot(U[:, :k], Vt[:k, :]), np.dot(Ua, Va), decimal=decimal
        )

        # check the sparse matrix representation
        for csr_container in CSR_CONTAINERS:
            # Keep the dense fixture `X` untouched: rebinding it here would
            # make later normalizer iterations run their "dense" checks on a
            # sparse matrix and wrap an already sparse matrix again.
            X_sparse = csr_container(X)

            # compute the singular values of X using the fast approximate method
            Ua, sa, Va = randomized_svd(
                X_sparse, k, power_iteration_normalizer=normalizer, random_state=0
            )

            assert_almost_equal(s[:rank], sa[:rank], decimal=decimal)


def test_cartesian():
# Check if cartesian product delivers the right results

Expand Down