
fit.transform not equal to fit_transform in NMF with solver 'mu' and init 'nndsvda' #18663

Open
@cmarmo


Describe the bug

The estimator check verifying that fit(X).transform(X) matches fit_transform(X) fails for the NMF estimator with solver='mu' and init='nndsvda'.

Discovered in #16948, see also #18505.

cc @jeremiedbb @TomDLT , maybe @vene . Thanks!

Steps/Code to Reproduce

from sklearn.utils.estimator_checks import check_transformer_general
from sklearn.decomposition import NMF

check_transformer_general('NMF', NMF(init='nndsvda', max_iter=1000, solver='mu'))
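The check above compares the output of fit_transform with the output of transform on the fitted estimator. Below is a rough, NumPy-only sketch of those two code paths, which may help when reasoning about why they can disagree. It is not scikit-learn's implementation. It assumes plain Lee & Seung multiplicative updates for the Frobenius loss, a fixed iteration count instead of a tolerance, a random non-negative initialization in place of 'nndsvda', and hypothetical random data. For the transform path it assumes W restarts from the constant sqrt(mean(X) / n_components), which is what scikit-learn 0.24 appears to do for solver='mu' (worth confirming in non_negative_factorization). The helper mu_frobenius and all variable names are hypothetical.

```python
import numpy as np


def mu_frobenius(X, W, H, n_iter, update_H=True, eps=1e-12):
    """Plain Lee & Seung multiplicative updates for ||X - W H||_F^2.

    Hypothetical helper for illustration only: no regularization, no
    tolerance-based stopping, not scikit-learn's implementation.
    """
    W = W.copy()
    H = H.copy()
    for _ in range(n_iter):
        # Update W with H held fixed.
        W *= (X @ H.T) / (W @ H @ H.T + eps)
        if update_H:
            # Update H with the new W held fixed.
            H *= (W.T @ X) / (W.T @ W @ H + eps)
    return W, H


rng = np.random.RandomState(0)
n_samples, n_features, k = 30, 3, 3
X = np.abs(rng.randn(n_samples, n_features))  # hypothetical non-negative data

# "fit_transform" path: W and H are optimized jointly and W is returned.
# Assumption: a random non-negative init stands in for init='nndsvda'.
W0 = np.abs(rng.randn(n_samples, k))
H0 = np.abs(rng.randn(k, n_features))
W_fit, H_fit = mu_frobenius(X, W0, H0, n_iter=200)

# "transform" path: H is frozen at the fitted value and W is re-solved
# from a constant start (assumed: sqrt(mean(X) / k), as for solver='mu').
W_start = np.full((n_samples, k), np.sqrt(X.mean() / k))
W_tr, H_kept = mu_frobenius(X, W_start, H_fit, n_iter=200, update_H=False)

# Same tolerances as the failing check: rtol=1e-7, atol=1e-2.
print("max |W_fit - W_tr|:", np.abs(W_fit - W_tr).max())
print("within tolerance:", np.allclose(W_fit, W_tr, rtol=1e-7, atol=1e-2))
```

In this sketch W is produced by two separate iterative runs from two different starting points, so the results can only agree to the extent that both runs converge. Multiplicative updates also change each entry only by a multiplicative factor per iteration (an exact zero stays zero), so entries heading towards zero may converge slowly.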

Expected Results

The test passes.

Actual Results

Traceback (most recent call last):
  File "debug_nmf.py", line 4, in <module>
    check_transformer_general(
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/_testing.py", line 302, in wrapper
    return fn(*args, **kwargs)
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/estimator_checks.py", line 1305, in check_transformer_general
    _check_transformer(name, transformer, X, y)
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/estimator_checks.py", line 1390, in _check_transformer
    assert_allclose_dense_sparse(
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/_testing.py", line 409, in assert_allclose_dense_sparse
    assert_allclose(x, y, rtol=rtol, atol=atol, err_msg=err_msg)
  File "/home/cmarmo/.skldevenv/lib64/python3.8/site-packages/numpy/testing/_private/utils.py", line 1532, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/cmarmo/.skldevenv/lib64/python3.8/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0.01
fit_transform and transform outcomes not consistent in NMF(init='nndsvda', max_iter=1000, random_state=0, solver='mu')
Mismatched elements: 18 / 90 (20%)
Max absolute difference: 0.04219191
Max relative difference: 1.
 x: array([[5.584139e-01, 5.114790e-01, 4.818158e-01],
       [7.176271e-02, 9.486441e-02, 2.635424e-01],
       [1.848232e-01, 1.241254e-01, 9.756835e-22],...
 y: array([[5.566442e-01, 5.129683e-01, 4.835650e-01],
       [7.235885e-02, 9.501380e-02, 2.622291e-01],
       [1.846617e-01, 1.244062e-01, 1.105227e-07],...
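The report can be puzzling at first. "Max relative difference: 1." looks alarming, but assert_allclose accepts an element when |x - y| <= atol + rtol * |y|, so with atol=0.01 the absolute term dominates. The short check below plugs two element pairs from the truncated arrays into that criterion. The near-zero pair at row 2, column 2 has a relative difference of essentially 1 (presumably the source of the reported maximum), yet it is within tolerance. In fact every pair visible in the three printed rows differs by less than 0.01, so the 18 mismatched elements must lie in the rows elided by "...", consistent with the reported maximum absolute difference of 0.04219191.

```python
import numpy as np

# Element pairs copied from the output above: x[0, 0] vs y[0, 0]
# and x[2, 2] vs y[2, 2].
x = np.array([5.584139e-01, 9.756835e-22])
y = np.array([5.566442e-01, 1.105227e-07])
rtol, atol = 1e-7, 1e-2        # tolerances reported in the AssertionError
reported_max_abs = 0.04219191  # "Max absolute difference" from the report

abs_diff = np.abs(x - y)
rel_diff = abs_diff / np.abs(y)

# assert_allclose accepts an element when |x - y| <= atol + rtol * |y|.
within_tol = abs_diff <= atol + rtol * np.abs(y)

print("absolute difference:", abs_diff)
print("relative difference:", rel_diff)
print("within tolerance:", within_tol)
print("reported max abs diff within atol:", reported_max_abs <= atol)
```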

Versions

System:
    python: 3.8.5 (default, Aug 12 2020, 00:00:00)  [GCC 10.2.1 20200723 (Red Hat 10.2.1-1)]
executable: /home/cmarmo/.skldevenv/bin/python
   machine: Linux-5.8.13-200.fc32.x86_64-x86_64-with-glibc2.2.5

Python dependencies:
          pip: 20.2.3
   setuptools: 41.6.0
      sklearn: 0.24.dev0
        numpy: 1.18.5
        scipy: 1.5.2
       Cython: 0.29.21
       pandas: 1.1.0
   matplotlib: 3.3.0
       joblib: 0.16.0
threadpoolctl: 2.1.0

Built with OpenMP: True
