Describe the bug
The test verifying that fit(X).transform(X) gives the same result as fit_transform(X) fails for the NMF estimator with solver='mu' and init='nndsvda'.
Discovered in #16948, see also #18505.
cc @jeremiedbb @TomDLT, maybe @vene. Thanks!
Steps/Code to Reproduce
from sklearn.utils.estimator_checks import check_transformer_general
from sklearn.decomposition import NMF
check_transformer_general('NMF', NMF(init='nndsvda', max_iter=1000, solver='mu'))
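To take the estimator_checks machinery out of the loop, the same comparison can be run directly. This is a sketch that assumes the input resembles what check_transformer_general builds in this dev version: 30 make_blobs samples with 3 features, standardized and then shifted to be non-negative. The internal helper may differ in detail. n_components=3 is passed explicitly; with n_components=None, NMF keeps all 3 features anyway.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.decomposition import NMF
from sklearn.preprocessing import StandardScaler

# Assumption: data built to resemble the input of check_transformer_general
# (30 samples, 3 features), not copied from the actual helper.
X, _ = make_blobs(n_samples=30, centers=[[0, 0, 0], [1, 1, 1]],
                  random_state=0, cluster_std=0.1)
X = StandardScaler().fit_transform(X)
X -= X.min()  # NMF requires non-negative input

params = dict(n_components=3, init="nndsvda", solver="mu",
              max_iter=1000, random_state=0)

# Path 1: W is computed while fitting.
W_fit_transform = NMF(**params).fit_transform(X)

# Path 2: fit first, then recompute W from the learned components_.
W_transform = NMF(**params).fit(X).transform(X)

# Same tolerances as in the failing assertion (rtol=1e-7, atol=1e-2).
max_abs_diff = np.abs(W_fit_transform - W_transform).max()
consistent = bool(np.allclose(W_fit_transform, W_transform, rtol=1e-7, atol=1e-2))
print("max abs difference:", max_abs_diff)
print("consistent at atol=1e-2:", consistent)
```

Whether this reports a mismatch depends on the scikit-learn version installed, so no output is shown here.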
Expected Results
The test passes.
Actual Results
Traceback (most recent call last):
  File "debug_nmf.py", line 4, in <module>
    check_transformer_general(
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/_testing.py", line 302, in wrapper
    return fn(*args, **kwargs)
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/estimator_checks.py", line 1305, in check_transformer_general
    _check_transformer(name, transformer, X, y)
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/estimator_checks.py", line 1390, in _check_transformer
    assert_allclose_dense_sparse(
  File "/home/cmarmo/software/scikit-learn/sklearn/utils/_testing.py", line 409, in assert_allclose_dense_sparse
    assert_allclose(x, y, rtol=rtol, atol=atol, err_msg=err_msg)
  File "/home/cmarmo/.skldevenv/lib64/python3.8/site-packages/numpy/testing/_private/utils.py", line 1532, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/cmarmo/.skldevenv/lib64/python3.8/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-07, atol=0.01
fit_transform and transform outcomes not consistent in NMF(init='nndsvda', max_iter=1000, random_state=0, solver='mu')
Mismatched elements: 18 / 90 (20%)
Max absolute difference: 0.04219191
Max relative difference: 1.
 x: array([[5.584139e-01, 5.114790e-01, 4.818158e-01],
       [7.176271e-02, 9.486441e-02, 2.635424e-01],
       [1.848232e-01, 1.241254e-01, 9.756835e-22],...
 y: array([[5.566442e-01, 5.129683e-01, 4.835650e-01],
       [7.235885e-02, 9.501380e-02, 2.622291e-01],
       [1.846617e-01, 1.244062e-01, 1.105227e-07],...
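As a sanity check on the assertion message, the three rows printed for x and y can be run through the criterion numpy.testing.assert_allclose applies, |x - y| <= atol + rtol * |y|. This sketch only covers the visible rows; the other 27 rows are elided in the report.

```python
import numpy as np

# The first three rows of x and y as printed above (the rest is elided).
x = np.array([[5.584139e-01, 5.114790e-01, 4.818158e-01],
              [7.176271e-02, 9.486441e-02, 2.635424e-01],
              [1.848232e-01, 1.241254e-01, 9.756835e-22]])
y = np.array([[5.566442e-01, 5.129683e-01, 4.835650e-01],
              [7.235885e-02, 9.501380e-02, 2.622291e-01],
              [1.846617e-01, 1.244062e-01, 1.105227e-07]])

rtol, atol = 1e-7, 1e-2
abs_diff = np.abs(x - y)
rel_diff = abs_diff / np.abs(y)
# assert_allclose passes an element when |x - y| <= atol + rtol * |y|
within_tol = abs_diff <= atol + rtol * np.abs(y)

print(abs_diff.max())    # about 1.77e-03, below atol
print(rel_diff.max())    # about 1.0, from 9.756835e-22 vs 1.105227e-07
print(within_tol.all())  # True
```

Every visible entry is within atol, so the 18 mismatches (up to 0.0422 in absolute terms) must sit in the elided rows. The reported max relative difference of 1 is consistent with near-zero entries such as 9.756835e-22 vs 1.105227e-07, which still pass on atol alone.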
Versions
System:
python: 3.8.5 (default, Aug 12 2020, 00:00:00) [GCC 10.2.1 20200723 (Red Hat 10.2.1-1)]
executable: /home/cmarmo/.skldevenv/bin/python
machine: Linux-5.8.13-200.fc32.x86_64-x86_64-with-glibc2.2.5
Python dependencies:
pip: 20.2.3
setuptools: 41.6.0
sklearn: 0.24.dev0
numpy: 1.18.5
scipy: 1.5.2
Cython: 0.29.21
pandas: 1.1.0
matplotlib: 3.3.0
joblib: 0.16.0
threadpoolctl: 2.1.0
Built with OpenMP: True