Skip to content

Commit 69e4c14

Browse files
devashishd12 authored and jakirkham committed
FIX in randomized_svd flip sign
Flip sign according to `u` in both cases of `transpose`.
1 parent 8a326f8 commit 69e4c14

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

sklearn/utils/extmath.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,12 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter=0,
305305
U = np.dot(Q, Uhat)
306306

307307
if flip_sign:
308-
U, V = svd_flip(U, V)
308+
if not transpose:
309+
U, V = svd_flip(U, V)
310+
else:
311+
# When transpose is set, pass u_based_decision=False so that the
312+
# sign flip is still based on u, and not on v, of the returned result.
313+
U, V = svd_flip(U, V, u_based_decision=False)
309314

310315
if transpose:
311316
# transpose back the results according to the input convention

sklearn/utils/tests/test_extmath.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.testing import assert_array_equal
1515
from sklearn.utils.testing import assert_array_almost_equal
1616
from sklearn.utils.testing import assert_true
17+
from sklearn.utils.testing import assert_false
1718
from sklearn.utils.testing import assert_greater
1819
from sklearn.utils.testing import assert_raises
1920
from sklearn.utils.testing import skip_if_32bit
@@ -287,6 +288,37 @@ def test_randomized_svd_sign_flip():
287288
assert_almost_equal(np.dot(v2.T, v2), np.eye(2))
288289

289290

291+
def test_randomized_svd_sign_flip_with_transpose():
292+
# Check if the randomized_svd sign flipping is always done based on u
293+
# irrespective of transpose.
294+
# See https://github.com/scikit-learn/scikit-learn/issues/5608
295+
# for more details.
296+
def max_loading_is_positive(u, v):
297+
"""
298+
Return a tuple of bools telling whether the entry of largest absolute
299+
value is positive in every column of u and in every row of v.
300+
"""
301+
u_based = (np.abs(u).max(axis=0) == u.max(axis=0)).all()
302+
v_based = (np.abs(v).max(axis=1) == v.max(axis=1)).all()
303+
return u_based, v_based
304+
305+
mat = np.arange(10 * 8).reshape(10, -1)
306+
307+
# Without transpose
308+
u_flipped, _, v_flipped = randomized_svd(mat, 3, flip_sign=True)
309+
u_based, v_based = max_loading_is_positive(u_flipped, v_flipped)
310+
assert_true(u_based)
311+
assert_false(v_based)
312+
313+
# With transpose
314+
u_flipped_with_transpose, _, v_flipped_with_transpose = randomized_svd(
315+
mat, 3, flip_sign=True, transpose=True)
316+
u_based, v_based = max_loading_is_positive(
317+
u_flipped_with_transpose, v_flipped_with_transpose)
318+
assert_true(u_based)
319+
assert_false(v_based)
320+
321+
290322
def test_cartesian():
291323
# Check if cartesian product delivers the right results
292324

0 commit comments

Comments
 (0)