[WIP] PERF Parallelize W/H updates of NMF with OpenMP #16439

Open: wants to merge 17 commits into base `main`
8 changes: 8 additions & 0 deletions doc/modules/decomposition.rst
@@ -820,6 +820,14 @@ stored components::
>>> X_new = np.array([[1, 0], [1, 6.1], [1, 0], [1, 4], [3.2, 1], [0, 4]])
>>> W_new = model.transform(X_new)

Low-level parallelism
---------------------

The coordinate descent solver (`solver='cd'`) uses OpenMP-based parallelism
through Cython. The updates of `W` (resp. `H`) are computed in parallel over
the samples (resp. features). For more details on how to control the number
of threads, please refer to our :ref:`parallelism` notes.
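For instance, the number of OpenMP threads used by these updates can be capped with `threadpoolctl` (a scikit-learn dependency); this is a hedged sketch, the exact speedup and the `init`/`n_components` values chosen here are illustrative:

```python
import numpy as np
from sklearn.decomposition import NMF
from threadpoolctl import threadpool_limits

# Small non-negative toy matrix (NMF requires X >= 0).
X = np.abs(np.random.RandomState(0).randn(100, 20))

# Limit the OpenMP thread pool used by the Cython coordinate descent kernel.
with threadpool_limits(limits=2, user_api="openmp"):
    model = NMF(n_components=5, solver="cd", init="nndsvda", random_state=0)
    W = model.fit_transform(X)

print(W.shape)  # (100, 5)
```

Setting the `OMP_NUM_THREADS` environment variable before starting Python has the same effect process-wide.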

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_decomposition_plot_faces_decomposition.py`
5 changes: 5 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -106,6 +106,11 @@ Changelog
:func:`decomposition.non_negative_factorization` now preserves float32 dtype.
:pr:`16280` by :user:`Jeremie du Boisberranger <jeremiedbb>`.

- |Efficiency| Improved efficiency of :class:`decomposition.NMF` when
  `solver="cd"` by computing the updates of `W` (resp. `H`) in parallel over
  the samples (resp. features). :pr:`16439` by :user:`macg0406` and
  :user:`Jeremie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.ensemble`
.......................

29 changes: 14 additions & 15 deletions sklearn/decomposition/_cdnmf_fast.pyx
@@ -1,42 +1,41 @@
-# cython: cdivision=True
-# cython: boundscheck=False
-# cython: wraparound=False
-
+# cython: cdivision=True, boundscheck=False, wraparound=False
 #
 # Author: Mathieu Blondel, Tom Dupre la Tour
 # License: BSD 3 clause

 from cython cimport floating
 from libc.math cimport fabs
+from cython.parallel import prange

+from ..utils._cython_blas cimport _dot
+from ..utils._openmp_helpers import _openmp_effective_n_threads


-def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
-                       floating[:, :] XHt, Py_ssize_t[::1] permutation):
+def _update_cdnmf_fast(floating[:, ::1] W, floating[:, ::1] HHt,
+                       floating[:, ::1] XHt, Py_ssize_t[::1] permutation):
     cdef:
         floating violation = 0
         Py_ssize_t n_components = W.shape[1]
         Py_ssize_t n_samples = W.shape[0]  # n_features for H update
         floating grad, pg, hess
-        Py_ssize_t i, r, s, t
+        Py_ssize_t i, s, t
+        int num_threads = _openmp_effective_n_threads()

     with nogil:
         for s in range(n_components):
             t = permutation[s]

-            for i in range(n_samples):
-                # gradient = GW[t, i] where GW = np.dot(W, HHt) - XHt
-                grad = -XHt[i, t]
-
-                for r in range(n_components):
-                    grad += HHt[t, r] * W[i, r]
+            for i in prange(n_samples, num_threads=num_threads):
+                # gradient = GW[i, t] where GW = np.dot(W, HHt.T) - XHt
+                grad = _dot(n_components, &HHt[t, 0], 1, &W[i, 0], 1) - XHt[i, t]

                 # projected gradient
                 pg = min(0., grad) if W[i, t] == 0 else grad
                 violation += fabs(pg)

                 # Hessian
                 hess = HHt[t, t]

                 if hess != 0:
                     W[i, t] = max(W[i, t] - grad / hess, 0.)

     return violation
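The parallelization is safe because, for a fixed component `t`, the update of each row `W[i, t]` depends only on row `i`, so iterations of the inner loop over samples are independent. A pure-NumPy reference of the same update (an illustrative reimplementation for this discussion, not library code) makes that structure explicit:

```python
import numpy as np

def update_cdnmf_ref(W, HHt, XHt, permutation):
    """Illustrative NumPy version of the coordinate descent update.

    For each component t (in the given permutation), performs an exact
    projected Newton step on W[:, t] for the objective 0.5 * ||X - W H||^2,
    where HHt = H @ H.T and XHt = X @ H.T. W is updated in place.
    """
    violation = 0.0
    for t in permutation:
        # Gradient w.r.t. W[:, t]: (W @ HHt - XHt)[:, t]; every sample i is
        # independent here, which is what makes the Cython prange loop safe.
        grad = W @ HHt[t] - XHt[:, t]
        # Projected gradient: at the boundary W == 0, only negative
        # gradients (pointing into the feasible region) count.
        pg = np.where(W[:, t] == 0, np.minimum(0.0, grad), grad)
        violation += np.abs(pg).sum()
        hess = HHt[t, t]
        if hess != 0:
            # Exact minimizer of the 1-d quadratic, clipped to W >= 0.
            W[:, t] = np.maximum(W[:, t] - grad / hess, 0.0)
    return violation
```

Because each step exactly minimizes a coordinate-wise quadratic over the non-negative half-line, a sweep never increases the factorization objective.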