Skip to content

Commit cfa498c

Browse files
devashishd12 authored and ogrisel committed
FIX in randomized_svd flip sign
Flip sign according to `u` in both cases of `transpose`.
1 parent 83d1cf4 commit cfa498c

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

sklearn/utils/extmath.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,12 @@ def randomized_svd(M, n_components, n_oversamples=10, n_iter=2,
363363
U = np.dot(Q, Uhat)
364364

365365
if flip_sign:
366-
U, V = svd_flip(U, V)
366+
if not transpose:
367+
U, V = svd_flip(U, V)
368+
else:
369+
# In case of transpose, pass u_based_decision=False
370+
# to actually flip based on u and not v.
371+
U, V = svd_flip(U, V, u_based_decision=False)
367372

368373
if transpose:
369374
# transpose back the results according to the input convention

sklearn/utils/tests/test_extmath.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.testing import assert_array_equal
1515
from sklearn.utils.testing import assert_array_almost_equal
1616
from sklearn.utils.testing import assert_true
17+
from sklearn.utils.testing import assert_false
1718
from sklearn.utils.testing import assert_greater
1819
from sklearn.utils.testing import assert_raises
1920
from sklearn.utils.testing import skip_if_32bit
@@ -328,6 +329,37 @@ def test_randomized_svd_sign_flip():
328329
assert_almost_equal(np.dot(v2.T, v2), np.eye(2))
329330

330331

332+
def test_randomized_svd_sign_flip_with_transpose():
333+
# Check if the randomized_svd sign flipping is always done based on u
334+
# irrespective of transpose.
335+
# See https://github.com/scikit-learn/scikit-learn/issues/5608
336+
# for more details.
337+
def max_loading_is_positive(u, v):
338+
"""
339+
returns bool tuple indicating if the values maximising np.abs
340+
are positive in every column of u and in every row of v.
341+
"""
342+
u_based = (np.abs(u).max(axis=0) == u.max(axis=0)).all()
343+
v_based = (np.abs(v).max(axis=1) == v.max(axis=1)).all()
344+
return u_based, v_based
345+
346+
mat = np.arange(10 * 8).reshape(10, -1)
347+
348+
# Without transpose
349+
u_flipped, _, v_flipped = randomized_svd(mat, 3, flip_sign=True)
350+
u_based, v_based = max_loading_is_positive(u_flipped, v_flipped)
351+
assert_true(u_based)
352+
assert_false(v_based)
353+
354+
# With transpose
355+
u_flipped_with_transpose, _, v_flipped_with_transpose = randomized_svd(
356+
mat, 3, flip_sign=True, transpose=True)
357+
u_based, v_based = max_loading_is_positive(
358+
u_flipped_with_transpose, v_flipped_with_transpose)
359+
assert_true(u_based)
360+
assert_false(v_based)
361+
362+
331363
def test_cartesian():
332364
# Check if cartesian product delivers the right results
333365

0 commit comments

Comments
 (0)