[MRG+1] Add multiplicative-update solver in NMF, with all beta-divergence #5295

Merged
merged 35 commits on Dec 12, 2016
Commits (35)
6e071f2
ENH add multiplicative-update solver in NMF, with all beta-divergence
TomDLT Jan 25, 2016
7adab35
improve docstring
TomDLT Aug 22, 2016
874944d
DOC Add links to references
TomDLT Sep 1, 2016
f167ebf
FIX link to reference
TomDLT Sep 1, 2016
e7ca049
add warning when solver='mu' and init='nndsvd'
TomDLT Sep 20, 2016
4b5ac3f
revert all change in benchmark (separate PR)
TomDLT Sep 21, 2016
fa246d1
add example in the doc and docstring
TomDLT Sep 21, 2016
cbd78a5
change versionadded to 0.19
TomDLT Sep 21, 2016
4277ae6
fix doctest
TomDLT Sep 21, 2016
b7a63d4
address review's comments
TomDLT Oct 3, 2016
4ad79d7
Temporary: test a stopping criterion in nmf-MU
TomDLT Oct 3, 2016
bd6474a
update convergence criterion and tests to avoid warnings
TomDLT Oct 4, 2016
9923554
normalize convergence criterion with error_at_init
TomDLT Oct 4, 2016
44bffa7
Fix test adding a copy of shared inititalization
TomDLT Oct 4, 2016
71d2d12
add NMF with KL divergence in topic extraction example
TomDLT Oct 4, 2016
e572448
Fix add init parameter for custom init
TomDLT Oct 4, 2016
a732dae
decrease to 10 iteration between convergence test
TomDLT Oct 4, 2016
4f20b12
Fix the reconstruction error from x**2 / 2 to x
TomDLT Oct 4, 2016
3b18a45
fix init docstring
TomDLT Oct 4, 2016
6dff7c9
typo and improve test decreasing
TomDLT Oct 5, 2016
6b20e30
remove unused private function _safe_compute_error
TomDLT Oct 5, 2016
0ee3cbf
make beta_divergence function private
TomDLT Oct 6, 2016
f428d20
Remove deprecated ProjectedGradientNMF
TomDLT Oct 6, 2016
dd4d6b5
remove warning in test
TomDLT Oct 6, 2016
31f2c0c
update doc
TomDLT Oct 6, 2016
44779a1
FIX raise an error when beta_loss <= 0 and X contains zeros
TomDLT Oct 6, 2016
4713b1c
TYPO epsilson -> epsilon
TomDLT Oct 6, 2016
0d9fb50
remove other occurences of ProjectedGradientNMF
TomDLT Oct 6, 2016
42de807
add whats_new.rst entry
TomDLT Oct 6, 2016
2af4a23
minor leftovers
TomDLT Oct 6, 2016
259d827
non-ascii and nitpick
TomDLT Oct 7, 2016
057c70c
safe_min instead of min
TomDLT Oct 7, 2016
a9ac84a
solve conflict with master
TomDLT Nov 29, 2016
5e017c6
Merge branch 'master' into nmf_mu
TomDLT Dec 12, 2016
928ea89
minor doc update
TomDLT Dec 12, 2016
3 changes: 1 addition & 2 deletions doc/modules/classes.rst
@@ -322,7 +322,6 @@ Samples generator

decomposition.PCA
decomposition.IncrementalPCA
decomposition.ProjectedGradientNMF
decomposition.KernelPCA
decomposition.FactorAnalysis
decomposition.FastICA
@@ -1058,7 +1057,7 @@ See the :ref:`metrics` section of the user guide for further details.
neighbors.DistanceMetric
neighbors.KernelDensity
neighbors.LocalOutlierFactor

.. autosummary::
:toctree: generated/
:template: function.rst
115 changes: 92 additions & 23 deletions doc/modules/decomposition.rst
@@ -648,27 +648,26 @@ components with some sparsity:
Non-negative matrix factorization (NMF or NNMF)
===============================================

:class:`NMF` is an alternative approach to decomposition that assumes that the
NMF with the Frobenius norm
---------------------------

:class:`NMF` [1]_ is an alternative approach to decomposition that assumes that the
data and the components are non-negative. :class:`NMF` can be plugged in
instead of :class:`PCA` or its variants, in the cases where the data matrix
does not contain negative values.
It finds a decomposition of samples :math:`X`
into two matrices :math:`W` and :math:`H` of non-negative elements,
by optimizing for the squared Frobenius norm:
does not contain negative values. It finds a decomposition of samples
:math:`X` into two matrices :math:`W` and :math:`H` of non-negative elements,
by optimizing the distance :math:`d` between :math:`X` and the matrix product
:math:`WH`. The most widely used distance function is the squared Frobenius
norm, which is an obvious extension of the Euclidean norm to matrices:

.. math::
\arg\min_{W,H} \frac{1}{2} ||X - WH||_{Fro}^2 = \frac{1}{2} \sum_{i,j} (X_{ij} - {WH}_{ij})^2

This norm is an obvious extension of the Euclidean norm to matrices. (Other
optimization objectives have been suggested in the NMF literature, in
particular Kullback-Leibler divergence, but these are not currently
implemented.)
d_{Fro}(X, Y) = \frac{1}{2} ||X - Y||_{Fro}^2 = \frac{1}{2} \sum_{i,j} (X_{ij} - {Y}_{ij})^2

Unlike :class:`PCA`, the representation of a vector is obtained in an additive
fashion, by superimposing the components, without subtracting. Such additive
models are efficient for representing images and text.

It has been observed in [Hoyer, 04] that, when carefully constrained,
It has been observed in [Hoyer, 2004] [2]_ that, when carefully constrained,
:class:`NMF` can produce a parts-based representation of the dataset,
resulting in interpretable models. The following example displays 16
sparse components found by :class:`NMF` from the images in the Olivetti
@@ -686,8 +685,8 @@ faces dataset, in comparison with the PCA eigenfaces.


The :attr:`init` attribute determines the initialization method applied, which
has a great impact on the performance of the method. :class:`NMF` implements
the method Nonnegative Double Singular Value Decomposition. NNDSVD is based on
has a great impact on the performance of the method. :class:`NMF` implements the
method Nonnegative Double Singular Value Decomposition. NNDSVD [4]_ is based on
two SVD processes, one approximating the data matrix, the other approximating
positive sections of the resulting partial SVD factors utilizing an algebraic
property of unit rank matrices. The basic NNDSVD algorithm is better fit for
@@ -696,6 +695,11 @@ the mean of all elements of the data), and NNDSVDar (in which the zeros are set
to random perturbations less than the mean of the data divided by 100) are
recommended in the dense case.

Note that the Multiplicative Update ('mu') solver cannot update zeros present in
the initialization, so it leads to poorer results when used jointly with the
basic NNDSVD algorithm which introduces a lot of zeros; in this case, NNDSVDa or
NNDSVDar should be preferred.
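
Because the 'mu' solver cannot turn zero entries back into non-zero values, a
dense NNDSVD variant is the natural companion initialization. A minimal sketch
(random data and arbitrary parameters, for illustration only)::

>>> import numpy as np
>>> from sklearn.decomposition import NMF
>>> X = np.abs(np.random.RandomState(0).randn(20, 10))  # non-negative data
>>> # 'nndsvda' fills the zeros of plain NNDSVD with the data mean, so the
>>> # multiplicative updates can modify every entry of W and H
>>> model = NMF(n_components=5, solver='mu', init='nndsvda', random_state=0)
>>> W = model.fit_transform(X)
>>> H = model.components_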

:class:`NMF` can also be initialized with correctly scaled random non-negative
matrices by setting :attr:`init="random"`. An integer seed or a
``RandomState`` can also be passed to :attr:`random_state` to control
@@ -716,7 +720,7 @@ and the intensity of the regularization with the :attr:`alpha`
and the regularized objective function is:

.. math::
\frac{1}{2}||X - WH||_{Fro}^2
d_{Fro}(X, WH)
+ \alpha \rho ||W||_1 + \alpha \rho ||H||_1
+ \frac{\alpha(1-\rho)}{2} ||W||_{Fro} ^ 2
+ \frac{\alpha(1-\rho)}{2} ||H||_{Fro} ^ 2
@@ -725,35 +729,100 @@ and the regularized objective function is:
:func:`non_negative_factorization` allows a finer control through the
:attr:`regularization` attribute, and may regularize only W, only H, or both.
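
For instance, regularizing only the components matrix H could look like the
following sketch (parameter values are arbitrary, for illustration only)::

>>> import numpy as np
>>> from sklearn.decomposition import non_negative_factorization
>>> X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])
>>> # regularization='components' penalizes only H; 'transformation' would
>>> # penalize only W, and 'both' penalizes both factors
>>> W, H, n_iter = non_negative_factorization(
...     X, n_components=2, init='random', random_state=0,
...     alpha=0.1, l1_ratio=0.5, regularization='components')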

NMF with a beta-divergence
--------------------------

As described previously, the most widely used distance function is the squared
Frobenius norm, which is an obvious extension of the Euclidean norm to
matrices:

.. math::
d_{Fro}(X, Y) = \frac{1}{2} ||X - Y||_{Fro}^2 = \frac{1}{2} \sum_{i,j} (X_{ij} - {Y}_{ij})^2
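
In NumPy terms, for two matrices ``X`` and ``Y`` of the same shape, this is a
one-liner (a sketch, not the implementation used by the solvers)::

>>> import numpy as np
>>> X, Y = np.ones((3, 2)), np.full((3, 2), 1.5)
>>> d_fro = 0.5 * np.sum((X - Y) ** 2)  # equals 0.75 here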

Other distance functions can be used in NMF, for example the (generalized)
Kullback-Leibler (KL) divergence, also referred to as the I-divergence:

.. math::
d_{KL}(X, Y) = \sum_{i,j} (X_{ij} \log(\frac{X_{ij}}{Y_{ij}}) - X_{ij} + Y_{ij})

Or, the Itakura-Saito (IS) divergence:

.. math::
d_{IS}(X, Y) = \sum_{i,j} (\frac{X_{ij}}{Y_{ij}} - \log(\frac{X_{ij}}{Y_{ij}}) - 1)

These three distances are special cases of the beta-divergence family, with
:math:`\beta = 2, 1, 0` respectively [6]_. The beta-divergence is
defined by:

.. math::
d_{\beta}(X, Y) = \sum_{i,j} \frac{1}{\beta(\beta - 1)}(X_{ij}^\beta + (\beta-1)Y_{ij}^\beta - \beta X_{ij} Y_{ij}^{\beta - 1})
@ogrisel (Member) commented on Oct 3, 2016:

Could you please insert the figure of the beta divergence loss function example after this formula?

@TomDLT (Member, Author) replied:

Done


.. figure:: ../auto_examples/decomposition/images/sphx_glr_plot_beta_divergence_001.png
:target: ../auto_examples/decomposition/plot_beta_divergence.html
:align: center
:scale: 75%

Note that this definition is not valid at :math:`\beta = 1` or :math:`\beta = 0`
(where the factor :math:`\frac{1}{\beta(\beta - 1)}` is undefined), yet the
formula can be continuously extended at these values, recovering the definitions
of :math:`d_{KL}` and :math:`d_{IS}` respectively.
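
As a numerical check of the formula above, the sketch below uses a standalone
scalar helper (it is not the private function added in this pull request) and
verifies the continuous extension at :math:`\beta = 1`::

>>> import numpy as np
>>> def beta_div(x, y, beta):
...     # scalar beta-divergence; beta = 2, 1, 0 give Frobenius, KL and IS
...     if beta == 0:
...         return x / y - np.log(x / y) - 1
...     if beta == 1:
...         return x * np.log(x / y) - x + y
...     return (x ** beta + (beta - 1) * y ** beta
...             - beta * x * y ** (beta - 1)) / (beta * (beta - 1))
>>> beta_div(1.0, 1.5, 2)  # 0.5 * (1 - 1.5) ** 2
0.125
>>> bool(np.allclose(beta_div(1.0, 1.5, 1 + 1e-8), beta_div(1.0, 1.5, 1)))
True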

:class:`NMF` implements two solvers, using Coordinate Descent ('cd') [5]_, and
Multiplicative Update ('mu') [6]_. The 'mu' solver can optimize every
beta-divergence, including of course the Frobenius norm (:math:`\beta=2`), the
(generalized) Kullback-Leibler divergence (:math:`\beta=1`) and the
Itakura-Saito divergence (:math:`\beta=0`). Note that for
:math:`\beta \in (1; 2)`, the 'mu' solver is significantly faster than for other
values of :math:`\beta`. Note also that with a negative (or 0, i.e.
'itakura-saito') :math:`\beta`, the input matrix cannot contain zero values.

The 'cd' solver can only optimize the Frobenius norm. Due to the
underlying non-convexity of NMF, the different solvers may converge to
different minima, even when optimizing the same distance function.
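
The sketch below (random data, arbitrary parameters) fits both solvers on the
same Frobenius objective; the resulting errors are typically close but need not
be identical::

>>> import numpy as np
>>> from sklearn.decomposition import NMF
>>> X = np.abs(np.random.RandomState(42).randn(30, 8))
>>> nmf_cd = NMF(n_components=4, solver='cd', init='nndsvda',
...              random_state=0).fit(X)
>>> nmf_mu = NMF(n_components=4, solver='mu', init='nndsvda',
...              random_state=0).fit(X)
>>> # both minimize the squared Frobenius norm, possibly at different minima
>>> nmf_cd.reconstruction_err_, nmf_mu.reconstruction_err_  # doctest: +SKIP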

NMF is best used with the ``fit_transform`` method, which returns the matrix W.
The matrix H is stored in the fitted model in the ``components_`` attribute;
the method ``transform`` will decompose a new matrix X_new based on these
stored components::

>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [3, 1.2], [4, 1], [5, 0.8], [6, 1]])
>>> from sklearn.decomposition import NMF
>>> model = NMF(n_components=2, init='random', random_state=0)
>>> W = model.fit_transform(X)
>>> H = model.components_
>>> X_new = np.array([[1, 0], [1, 6.1], [1, 0], [1, 4], [3.2, 1], [0, 4]])
>>> W_new = model.transform(X_new)
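
The same interface works with the new 'mu' solver and a non-Frobenius loss,
reusing the matrices from the snippet above (parameters are arbitrary)::

>>> model_kl = NMF(n_components=2, solver='mu', beta_loss='kullback-leibler',
...                init='random', random_state=0, max_iter=1000)
>>> W_kl = model_kl.fit_transform(X)
>>> H_kl = model_kl.components_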

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_decomposition_plot_faces_decomposition.py`
* :ref:`sphx_glr_auto_examples_applications_topics_extraction_with_nmf_lda.py`
* :ref:`sphx_glr_auto_examples_decomposition_plot_beta_divergence.py`

.. topic:: References:

* `"Learning the parts of objects by non-negative matrix factorization"
.. [1] `"Learning the parts of objects by non-negative matrix factorization"
<http://www.columbia.edu/~jwp2128/Teaching/W4721/papers/nmf_nature.pdf>`_
D. Lee, S. Seung, 1999

* `"Non-negative Matrix Factorization with Sparseness Constraints"
.. [2] `"Non-negative Matrix Factorization with Sparseness Constraints"
<http://www.jmlr.org/papers/volume5/hoyer04a/hoyer04a.pdf>`_
P. Hoyer, 2004

* `"Projected gradient methods for non-negative matrix factorization"
<http://www.csie.ntu.edu.tw/~cjlin/nmf/>`_
C.-J. Lin, 2007

* `"SVD based initialization: A head start for nonnegative
.. [4] `"SVD based initialization: A head start for nonnegative
matrix factorization"
<http://scgroup.hpclab.ceid.upatras.gr/faculty/stratis/Papers/HPCLAB020107.pdf>`_
C. Boutsidis, E. Gallopoulos, 2008

* `"Fast local algorithms for large scale nonnegative matrix and tensor
.. [5] `"Fast local algorithms for large scale nonnegative matrix and tensor
factorizations."
<http://www.bsp.brain.riken.jp/publications/2009/Cichocki-Phan-IEICE_col.pdf>`_
A. Cichocki, P. Anh-Huy, 2009

.. [6] `"Algorithms for nonnegative matrix factorization with the beta-divergence"
<http://arxiv.org/pdf/1010.1763v3.pdf>`_
C. Fevotte, J. Idier, 2011


.. _LatentDirichletAllocation:

26 changes: 16 additions & 10 deletions doc/whats_new.rst
@@ -35,6 +35,12 @@ New features
detection based on nearest neighbors.
:issue:`5279` by `Nicolas Goix`_ and `Alexandre Gramfort`_.

- The new solver ``mu`` implements a Multiplicative Update in
:class:`decomposition.NMF`, allowing the optimization of all
beta-divergences, including the Frobenius norm, the generalized
Kullback-Leibler divergence and the Itakura-Saito divergence.
By `Tom Dupre la Tour`_.

Enhancements
............

@@ -152,7 +158,7 @@ Bug fixes
with SVD and Eigen solver are now of the same length. :issue:`7632`
by :user:`JPFrancoia <JPFrancoia>`

- Fixes issue in :ref:`univariate_feature_selection` where score
- Fixes issue in :ref:`univariate_feature_selection` where score
functions were not accepting multi-label targets. :issue:`7676`
by `Mohammed Affan`_

@@ -382,7 +388,7 @@ Other estimators

- New :class:`mixture.GaussianMixture` and :class:`mixture.BayesianGaussianMixture`
replace former mixture models, employing faster inference
for sounder results. :issue:`7295` by :user:`Wei Xue <xuewei4d>` and
for sounder results. :issue:`7295` by :user:`Wei Xue <xuewei4d>` and
:user:`Thierry Guillemot <tguillemot>`.

- Class :class:`decomposition.RandomizedPCA` is now factored into :class:`decomposition.PCA`
@@ -505,7 +511,7 @@ Decomposition, manifold learning and clustering
- :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now works
with ``np.float32`` and ``np.float64`` input data without converting it.
This allows to reduce the memory consumption by using ``np.float32``.
:issue:`6846` by :user:`Sebastian Säger <ssaeger>` and
:issue:`6846` by :user:`Sebastian Säger <ssaeger>` and
:user:`YenChen Lin <yenchenlin>`.

Preprocessing and feature selection
@@ -514,7 +520,7 @@ Preprocessing and feature selection
:issue:`5929` by :user:`Konstantin Podshumok <podshumok>`.

- :class:`feature_extraction.FeatureHasher` now accepts string values.
:issue:`6173` by :user:`Ryad Zenine <ryadzenine>` and
:issue:`6173` by :user:`Ryad Zenine <ryadzenine>` and
:user:`Devashish Deshpande <dsquareindia>`.

- Keyword arguments can now be supplied to ``func`` in
@@ -528,7 +534,7 @@
Model evaluation and meta-estimators

- :class:`multiclass.OneVsOneClassifier` and :class:`multiclass.OneVsRestClassifier`
now support ``partial_fit``. By :user:`Asish Panda <kaichogami>` and
now support ``partial_fit``. By :user:`Asish Panda <kaichogami>` and
:user:`Philipp Dowling <phdowling>`.

- Added support for substituting or disabling :class:`pipeline.Pipeline`
@@ -556,7 +562,7 @@ Metrics

- Added ``labels`` flag to :class:`metrics.log_loss` to explicitly provide
the labels when the number of classes in ``y_true`` and ``y_pred`` differ.
:issue:`7239` by :user:`Hong Guangguo <hongguangguo>` with help from
:issue:`7239` by :user:`Hong Guangguo <hongguangguo>` with help from
:user:`Mads Jensen <indianajensen>` and :user:`Nelson Liu <nelson-liu>`.

- Support sparse contingency matrices in cluster evaluation
@@ -676,7 +682,7 @@ Decomposition, manifold learning and clustering
- Fixed incorrect initialization of :func:`utils.arpack.eigsh` on all
occurrences. Affects :class:`cluster.bicluster.SpectralBiclustering`,
:class:`decomposition.KernelPCA`, :class:`manifold.LocallyLinearEmbedding`,
and :class:`manifold.SpectralEmbedding` (:issue:`5012`). By
and :class:`manifold.SpectralEmbedding` (:issue:`5012`). By
:user:`Peter Fischer <yanlend>`.

- Attribute ``explained_variance_ratio_`` calculated with the SVD solver
Expand Down Expand Up @@ -959,7 +965,7 @@ New features
:class:`cross_validation.LabelShuffleSplit` generate train-test folds,
respectively similar to :class:`cross_validation.KFold` and
:class:`cross_validation.ShuffleSplit`, except that the folds are
conditioned on a label array. By `Brian McFee`_, :user:`Jean
conditioned on a label array. By `Brian McFee`_, :user:`Jean
Kossaifi <JeanKossaifi>` and `Gilles Louppe`_.

- :class:`decomposition.LatentDirichletAllocation` implements the Latent
@@ -1049,7 +1055,7 @@ Enhancements
By `Trevor Stephens`_.

- Provide an option for sparse output from
:func:`sklearn.metrics.pairwise.cosine_similarity`. By
:func:`sklearn.metrics.pairwise.cosine_similarity`. By
:user:`Jaidev Deshpande <jaidevd>`.

- Add :func:`minmax_scale` to provide a function interface for
@@ -1260,7 +1266,7 @@ Bug fixes
By `Tom Dupre la Tour`_.

- Fixed bug :issue:`5495` when
doing OVR(SVC(decision_function_shape="ovr")). Fixed by
doing OVR(SVC(decision_function_shape="ovr")). Fixed by
:user:`Elvis Dohmatob <dohmatob>`.


29 changes: 24 additions & 5 deletions examples/applications/topics_extraction_with_nmf_lda.py
@@ -9,6 +9,10 @@
The output is a list of topics, each represented as a list of terms
(weights are not shown).

Non-negative Matrix Factorization is applied with two different objective
functions: the Frobenius norm, and the generalized Kullback-Leibler divergence.
The latter is equivalent to Probabilistic Latent Semantic Indexing.

The default parameters (n_samples / n_features / n_topics) should make
the example runnable in a couple of tens of seconds. You can try to
increase the dimensions of the problem, but be aware that the time
@@ -36,9 +40,10 @@

def print_top_words(model, feature_names, n_top_words):
for topic_idx, topic in enumerate(model.components_):
print("Topic #%d:" % topic_idx)
print(" ".join([feature_names[i]
for i in topic.argsort()[:-n_top_words - 1:-1]]))
message = "Topic #%d: " % topic_idx
message += " ".join([feature_names[i]
for i in topic.argsort()[:-n_top_words - 1:-1]])
print(message)
print()


@@ -71,17 +76,31 @@ def print_top_words(model, feature_names, n_top_words):
t0 = time()
tf = tf_vectorizer.fit_transform(data_samples)
print("done in %0.3fs." % (time() - t0))
print()

# Fit the NMF model
print("Fitting the NMF model with tf-idf features, "
print("Fitting the NMF model (Frobenius norm) with tf-idf features, "
"n_samples=%d and n_features=%d..."
% (n_samples, n_features))
t0 = time()
nmf = NMF(n_components=n_topics, random_state=1,
alpha=.1, l1_ratio=.5).fit(tfidf)
print("done in %0.3fs." % (time() - t0))

print("\nTopics in NMF model:")
print("\nTopics in NMF model (Frobenius norm):")
tfidf_feature_names = tfidf_vectorizer.get_feature_names()
print_top_words(nmf, tfidf_feature_names, n_top_words)

# Fit the NMF model
print("Fitting the NMF model (generalized Kullback-Leibler divergence) with "
"tf-idf features, n_samples=%d and n_features=%d..."
% (n_samples, n_features))
t0 = time()
nmf = NMF(n_components=n_topics, random_state=1, beta_loss='kullback-leibler',
solver='mu', max_iter=1000, alpha=.1, l1_ratio=.5).fit(tfidf)
A reviewer (Member) commented on the line above:

typo: kullback-leibler

@TomDLT (Member, Author) replied:

thanks

print("done in %0.3fs." % (time() - t0))

print("\nTopics in NMF model (generalized Kullback-Leibler divergence):")
tfidf_feature_names = tfidf_vectorizer.get_feature_names()
print_top_words(nmf, tfidf_feature_names, n_top_words)

29 changes: 29 additions & 0 deletions examples/decomposition/plot_beta_divergence.py
@@ -0,0 +1,29 @@
"""
==============================
Beta-divergence loss functions
==============================

A plot that compares the various Beta-divergence loss functions supported by
the Multiplicative-Update ('mu') solver in :class:`sklearn.decomposition.NMF`.
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition.nmf import _beta_divergence

print(__doc__)

x = np.linspace(0.001, 4, 1000)
y = np.zeros(x.shape)

colors = 'mbgyr'
for j, beta in enumerate((0., 0.5, 1., 1.5, 2.)):
for i, xi in enumerate(x):
y[i] = _beta_divergence(1, xi, 1, beta)
name = "beta = %1.1f" % beta
plt.plot(x, y, label=name, color=colors[j])

plt.xlabel("x")
plt.title("beta-divergence(1, x)")
plt.legend(loc=0)
plt.axis([0, 4, 0, 3])
plt.show()