|
14 | 14 | from sklearn.utils.testing import assert_array_equal
|
15 | 15 | from sklearn.utils.testing import assert_array_almost_equal
|
16 | 16 | from sklearn.utils.testing import assert_true
|
| 17 | +from sklearn.utils.testing import assert_false |
17 | 18 | from sklearn.utils.testing import assert_greater
|
18 | 19 | from sklearn.utils.testing import assert_raises
|
19 | 20 | from sklearn.utils.testing import skip_if_32bit
|
@@ -328,6 +329,37 @@ def test_randomized_svd_sign_flip():
|
328 | 329 | assert_almost_equal(np.dot(v2.T, v2), np.eye(2))
|
329 | 330 |
|
330 | 331 |
|
| 332 | +def test_randomized_svd_sign_flip_with_transpose(): |
| 333 | + # Check if the randomized_svd sign flipping is always done based on u |
| 334 | + # irrespective of transpose. |
| 335 | + # See https://github.com/scikit-learn/scikit-learn/issues/5608 |
| 336 | + # for more details. |
| 337 | + def max_loading_is_positive(u, v): |
| 338 | + """ |
| 339 | + returns bool tuple indicating if the values maximising np.abs |
| 340 | + are positive across all rows for u and across all columns for v. |
| 341 | + """ |
| 342 | + u_based = (np.abs(u).max(axis=0) == u.max(axis=0)).all() |
| 343 | + v_based = (np.abs(v).max(axis=1) == v.max(axis=1)).all() |
| 344 | + return u_based, v_based |
| 345 | + |
| 346 | + mat = np.arange(10 * 8).reshape(10, -1) |
| 347 | + |
| 348 | + # Without transpose |
| 349 | + u_flipped, _, v_flipped = randomized_svd(mat, 3, flip_sign=True) |
| 350 | + u_based, v_based = max_loading_is_positive(u_flipped, v_flipped) |
| 351 | + assert_true(u_based) |
| 352 | + assert_false(v_based) |
| 353 | + |
| 354 | + # With transpose |
| 355 | + u_flipped_with_transpose, _, v_flipped_with_transpose = randomized_svd( |
| 356 | + mat, 3, flip_sign=True, transpose=True) |
| 357 | + u_based, v_based = max_loading_is_positive( |
| 358 | + u_flipped_with_transpose, v_flipped_with_transpose) |
| 359 | + assert_true(u_based) |
| 360 | + assert_false(v_based) |
| 361 | + |
| 362 | + |
331 | 363 | def test_cartesian():
|
332 | 364 | # Check if cartesian product delivers the right results
|
333 | 365 |
|
|
0 commit comments