[MRG] Fixed NMF IndexError #11667

Merged
merged 17 commits into from
Feb 12, 2019
10 changes: 10 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -76,6 +76,16 @@ Support for Python 3.4 and below has been officially dropped.
   the default value is used.
   :issue:`12988` by :user:`Zijie (ZJ) Poh <zjpoh>`.

+:mod:`sklearn.decomposition`
+............................
+
+- |Fix| Fixed a bug in :class:`decomposition.NMF` where `init = 'nndsvd'`,
+  `init = 'nndsvda'`, and `init = 'nndsvdar'` are allowed when
+  `n_components < n_features` instead of
+  `n_components <= min(n_samples, n_features)`.
+  :issue:`11650` by :user:`Hossein Pourbozorg <hossein-pourbozorg>` and
+  :user:`Zijie (ZJ) Poh <zjpoh>`.
+
 :mod:`sklearn.discriminant_analysis`
 ....................................

14 changes: 11 additions & 3 deletions sklearn/decomposition/nmf.py
@@ -261,7 +261,8 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6,
         Default: None.
         Valid options:

-        - None: 'nndsvd' if n_components < n_features, otherwise 'random'.
+        - None: 'nndsvd' if n_components <= min(n_samples, n_features),
+            otherwise 'random'.

         - 'random': non-negative random matrices, scaled with:
             sqrt(X.mean() / n_components)
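
The bound `min(n_samples, n_features)` in the updated docstring follows from the fact that a matrix has at most that many singular vectors. A minimal sketch of why the old condition `n_components < n_features` was too loose, assuming plain `np.linalg.svd` as a stand-in (the NNDSVD code itself is not shown in this diff and may use a different SVD routine); `component_available` is a hypothetical helper:

```python
import numpy as np

rng = np.random.RandomState(0)
X = np.abs(rng.randn(2, 5))  # n_samples=2, n_features=5

# A thin SVD yields only min(n_samples, n_features) singular triplets.
U, S, Vt = np.linalg.svd(X, full_matrices=False)
max_components = min(X.shape)


def component_available(j):
    """Hypothetical helper: True if the j-th singular pair can be indexed."""
    try:
        U[:, j], Vt[j, :]
    except IndexError:
        return False
    return True


# n_components=3 passes the old check (3 < n_features == 5),
# but there is no third singular pair to initialize from.
print(max_components, component_available(1), component_available(2))
```

With `n_samples=2`, asking for a third component satisfies `n_components < n_features` yet indexes past the available singular vectors, which matches the `IndexError` named in the PR title.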
@@ -304,8 +305,14 @@ def _initialize_nmf(X, n_components, init=None, eps=1e-6,
     check_non_negative(X, "NMF initialization")
     n_samples, n_features = X.shape

+    if (init is not None and init != 'random'
+            and n_components > min(n_samples, n_features)):
+        raise ValueError("init = '{}' can only be used when "
+                         "n_components <= min(n_samples, n_features)"
+                         .format(init))
+
     if init is None:
-        if n_components < n_features:
+        if n_components <= min(n_samples, n_features):
             init = 'nndsvd'
         else:
             init = 'random'
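
The two checks in this hunk can be read as one small decision rule. The sketch below restates them as a standalone function so the behavior is easy to try without scikit-learn; `resolve_init` is a hypothetical name used only for illustration and is not part of the library:

```python
def resolve_init(init, n_components, n_samples, n_features):
    """Hypothetical helper mirroring the validation added in this hunk."""
    max_rank = min(n_samples, n_features)

    # Explicit SVD-based inits are rejected when too many components are asked for.
    if init is not None and init != 'random' and n_components > max_rank:
        raise ValueError("init = '{}' can only be used when "
                         "n_components <= min(n_samples, n_features)"
                         .format(init))

    # With init=None, fall back to 'random' instead of raising.
    if init is None:
        init = 'nndsvd' if n_components <= max_rank else 'random'
    return init


print(resolve_init(None, 2, 2, 5))  # within the bound
print(resolve_init(None, 3, 2, 5))  # above the bound, no error
```

Note the asymmetry: an explicit `'nndsvd'`, `'nndsvda'`, or `'nndsvdar'` above the bound is an error, while `init=None` silently selects `'random'`.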
@@ -1104,7 +1111,8 @@ class NMF(BaseEstimator, TransformerMixin):
         Default: None.
         Valid options:

-        - None: 'nndsvd' if n_components < n_features, otherwise random.
+        - None: 'nndsvd' if n_components <= min(n_samples, n_features),
+            otherwise random.

         - 'random': non-negative random matrices, scaled with:
             sqrt(X.mean() / n_components)
34 changes: 23 additions & 11 deletions sklearn/decomposition/tests/test_nmf.py
@@ -53,6 +53,14 @@ def test_parameter_checking():
     clf = NMF(2, tol=0.1).fit(A)
     assert_raise_message(ValueError, msg, clf.transform, -A)

+    for init in ['nndsvd', 'nndsvda', 'nndsvdar']:
+        msg = ("init = '{}' can only be used when "
+               "n_components <= min(n_samples, n_features)"
+               .format(init))
+        assert_raise_message(ValueError, msg, NMF(3, init).fit, A)
+        assert_raise_message(ValueError, msg, nmf._initialize_nmf, A,
+                             3, init)
+

 def test_initialize_close():
     # Test NNDSVD error
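
From a user's point of view, the new test in `test_parameter_checking` corresponds roughly to the following. This is a sketch, assuming a scikit-learn release that includes this fix and a 2x2 input chosen here for illustration; `try_fit` is a hypothetical helper, and keyword arguments are used because newer releases may not accept the positional form `NMF(3, init)`:

```python
import numpy as np
from sklearn.decomposition import NMF

A = np.ones((2, 2))  # min(n_samples, n_features) == 2


def try_fit(init, n_components=3):
    """Hypothetical helper: return the ValueError message, or None on success."""
    try:
        NMF(n_components=n_components, init=init).fit(A)
    except ValueError as exc:
        return str(exc)
    return None


# SVD-based initializations should be rejected for n_components=3 ...
messages = {init: try_fit(init) for init in ('nndsvd', 'nndsvda', 'nndsvdar')}
# ... while 'random' carries no such restriction.
random_message = try_fit('random')

for init, message in messages.items():
    print(init, '->', message)
```

The fit with `init='random'` may emit a convergence warning on such a tiny input, but it should not raise.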
@@ -197,17 +205,21 @@ def test_non_negative_factorization_consistency():
     A = np.abs(rng.randn(10, 10))
     A[:, 2 * np.arange(5)] = 0

-    for solver in ('cd', 'mu'):
-        W_nmf, H, _ = non_negative_factorization(
-            A, solver=solver, random_state=1, tol=1e-2)
-        W_nmf_2, _, _ = non_negative_factorization(
-            A, H=H, update_H=False, solver=solver, random_state=1, tol=1e-2)
-
-        model_class = NMF(solver=solver, random_state=1, tol=1e-2)
-        W_cls = model_class.fit_transform(A)
-        W_cls_2 = model_class.transform(A)
-        assert_array_almost_equal(W_nmf, W_cls, decimal=10)
-        assert_array_almost_equal(W_nmf_2, W_cls_2, decimal=10)
+    for init in ['random', 'nndsvd']:
+        for solver in ('cd', 'mu'):
+            W_nmf, H, _ = non_negative_factorization(
+                A, init=init, solver=solver, random_state=1, tol=1e-2)
+            W_nmf_2, _, _ = non_negative_factorization(
+                A, H=H, update_H=False, init=init, solver=solver,
+                random_state=1, tol=1e-2)
+
+            model_class = NMF(init=init, solver=solver, random_state=1,
+                              tol=1e-2)
+            W_cls = model_class.fit_transform(A)
+            W_cls_2 = model_class.transform(A)
+
+            assert_array_almost_equal(W_nmf, W_cls, decimal=10)
+            assert_array_almost_equal(W_nmf_2, W_cls_2, decimal=10)


 def test_non_negative_factorization_checking():