|
14 | 14 | from sklearn.utils.testing import assert_array_equal
|
15 | 15 | from sklearn.utils.testing import assert_array_almost_equal
|
16 | 16 | from sklearn.utils.testing import assert_true
|
| 17 | +from sklearn.utils.testing import assert_false |
17 | 18 | from sklearn.utils.testing import assert_greater
|
18 | 19 | from sklearn.utils.testing import assert_raises
|
19 | 20 | from sklearn.utils.testing import skip_if_32bit
|
@@ -287,6 +288,37 @@ def test_randomized_svd_sign_flip():
|
287 | 288 | assert_almost_equal(np.dot(v2.T, v2), np.eye(2))
|
288 | 289 |
|
289 | 290 |
|
| 291 | +def test_randomized_svd_sign_flip_with_transpose(): |
| 292 | + # Check if the randomized_svd sign flipping is always done based on u |
| 293 | + # irrespective of transpose. |
| 294 | + # See https://github.com/scikit-learn/scikit-learn/issues/5608 |
| 295 | + # for more details. |
| 296 | + def max_loading_is_positive(u, v): |
| 297 | + """ |
| 298 | + returns bool tuple indicating if the values maximising np.abs |
| 299 | + are positive across all rows for u and across all columns for v. |
| 300 | + """ |
| 301 | + u_based = (np.abs(u).max(axis=0) == u.max(axis=0)).all() |
| 302 | + v_based = (np.abs(v).max(axis=1) == v.max(axis=1)).all() |
| 303 | + return u_based, v_based |
| 304 | + |
| 305 | + mat = np.arange(10 * 8).reshape(10, -1) |
| 306 | + |
| 307 | + # Without transpose |
| 308 | + u_flipped, _, v_flipped = randomized_svd(mat, 3, flip_sign=True) |
| 309 | + u_based, v_based = max_loading_is_positive(u_flipped, v_flipped) |
| 310 | + assert_true(u_based) |
| 311 | + assert_false(v_based) |
| 312 | + |
| 313 | + # With transpose |
| 314 | + u_flipped_with_transpose, _, v_flipped_with_transpose = randomized_svd( |
| 315 | + mat, 3, flip_sign=True, transpose=True) |
| 316 | + u_based, v_based = max_loading_is_positive( |
| 317 | + u_flipped_with_transpose, v_flipped_with_transpose) |
| 318 | + assert_true(u_based) |
| 319 | + assert_false(v_based) |
| 320 | + |
| 321 | + |
290 | 322 | def test_cartesian():
|
291 | 323 | # Check if cartesian product delivers the right results
|
292 | 324 |
|
|
0 commit comments