[MRG] Parallelization of _update_cdnmf_fast, fast NMF #6641
Conversation
Any benchmarks demonstrating the gains?
# gradient = GW[t, i] where GW = np.dot(W, HHt) - XHt
grad = -XHt[i, t]
# grad = -XHt[i, t]
You should remove the commented lines.
The standard way that we support parallel computing in scikit-learn is via joblib (here you would use the "threading" backend). It is robust to various compilers (not every compiler implements OpenMP), and gives an explicit way of controlling the number of CPUs used. Any reason that you didn't do it in this PR? We would need robustness to compiler support and explicit control of the number of CPUs used to include this in scikit-learn.
Also, you have failing tests on some architectures.
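For reference, here is a minimal sketch of the joblib pattern described above, with an illustrative row-chunked update (the helper name and the chunking are mine, not from this PR); the "threading" backend only pays off once the worker body releases the GIL, e.g. in compiled Cython:

import numpy as np
from joblib import Parallel, delayed

def _update_rows(W, HHt, XHt, rows):
    # Toy stand-in for a per-row projected update on a shared W;
    # a real speed-up would need a GIL-releasing compiled body.
    for i in rows:
        W[i] = np.maximum(W[i] - (W[i] @ HHt - XHt[i]) / np.diag(HHt), 0)

rng = np.random.RandomState(0)
W, H = rng.rand(1000, 8), rng.rand(8, 500)
X = W @ H
HHt, XHt = H @ H.T, X @ H.T

# "threading" backend: workers share memory (W is updated in place) and
# n_jobs gives explicit, compiler-independent control of the CPU count.
Parallel(n_jobs=4, backend="threading")(
    delayed(_update_rows)(W, HHt, XHt, chunk)
    for chunk in np.array_split(np.arange(W.shape[0]), 4)
)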
My intuition is that joblib isn't appropriate for parallelizing tight loops. In any case, I would also be interested in speed-up figures and, if @macg0406 has time, an OpenMP vs. joblib comparison.
The main problem preventing us from using OpenMP in the inner loops of scikit-learn is the bad interaction (silent freeze) with the multiprocessing used by joblib, for instance in a GridSearchCV wrapper. More details here: This will be solved by a new process pool manager we are currently working on at https://github.com/tomMoral/loky, which is planned to replace the default multiprocessing backend of joblib. This is not ready yet though.
@ogrisel can you maybe create an issue on that to track it? I think we definitely need to put some work into getting OpenMP to work, but I don't really have any insight into what's happening right now.
+1! Very interested in this too.
Done in #7650.
Any chance you could merge master here @macg0406? OpenMP is now supported, so this PR from 2016 becomes possible. cc @jeremiedbb
I made some comments. Also, all the commented lines need to be removed.
But after a quick profiling pass, it appears that this is not the critical part of the algorithm at all (see below). Although it doesn't hurt performance to use a prange here, the gain is very small.
Edit: I only had 2 components. With more components it does become the critical part.
cdef double grad = 0 - xht
cdef double pg
for r in range(n_components):
    # for(int r =0;r<n_components;r++)
missing cdef Py_ssize_t r
from cython.parallel import prange

cdef inline double _update_cdnmf_samples(unsigned n_components, double xht, double [] HHt, double[] W, double hess, unsigned t) nogil:
Following the previous declarations, n_components and r should be Py_ssize_t.
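For illustration, here is a sketch of the helper with the fixes asked for in these review comments applied (Py_ssize_t indices, pointer arguments, hess read from the HHt row instead of being passed in); this is my reading of the reviewers' intent, not the merged code:

from libc.math cimport fabs, fmax

cdef inline double _update_cdnmf_samples(Py_ssize_t n_components, double xht,
                                         double* HHt, double* W,
                                         Py_ssize_t t) nogil:
    # HHt points at row HHt[t, :] and W at row W[i, :] of the caller's arrays.
    # gradient = GW[t, i] where GW = np.dot(W, HHt) - XHt
    cdef double grad = -xht
    cdef double pg
    cdef Py_ssize_t r
    for r in range(n_components):
        grad += HHt[r] * W[r]
    # projected gradient, accumulated by the caller into the violation
    pg = min(0., grad) if W[t] == 0 else grad
    # the Hessian is the diagonal entry HHt[t, t] of the row pointed to
    if HHt[t] != 0:
        W[t] = fmax(W[t] - grad / HHt[t], 0.)
    return fabs(pg)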
def _update_cdnmf_fast(double[:, ::1] W, double[:, :] HHt, double[:, :] XHt,
                       Py_ssize_t[::1] permutation):
    cdef double violation = 0
    cdef Py_ssize_t n_components = W.shape[1]
    cdef Py_ssize_t n_samples = W.shape[0]  # n_features for H update
    cdef double grad, pg, hess
    cdef double pg, hess
pg is not used any more
if hess != 0:
    W[t] -= grad / hess
    if W[t] < 0:
        W[t] = 0
W[t] = fmax(W[t], 0) is probably a bit faster
with nogil:
    for s in range(n_components):
        t = permutation[s]
        # Hessian
        hess = HHt[t, t]
No need to define hess here. It can be done directly in _update_cdnmf_samples.
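A sketch of the matching driver, with the prange over samples and hess folded into the helper as suggested; contiguous memoryviews are assumed here so that the row pointers are valid (again an illustration, not the merged code):

from cython.parallel import prange

def _update_cdnmf_fast(double[:, ::1] W, double[:, ::1] HHt,
                       double[:, ::1] XHt, Py_ssize_t[::1] permutation):
    cdef double violation = 0
    cdef Py_ssize_t n_components = W.shape[1]
    cdef Py_ssize_t n_samples = W.shape[0]  # n_features for the H update
    cdef Py_ssize_t i, s, t

    with nogil:
        for s in range(n_components):
            t = permutation[s]
            # each iteration touches only row i of W, so the rows can be
            # updated in parallel; violation is handled as a prange reduction
            for i in prange(n_samples):
                violation += _update_cdnmf_samples(
                    n_components, XHt[i, t], &HHt[t, 0], &W[i, 0], t)
    return violation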
cdef inline double _update_cdnmf_samples(unsigned n_components, double xht, double [] HHt, double[] W, double hess, unsigned t) nogil:
I think double* W and double* HHt would be more clear.
cdef Py_ssize_t i, r, s, t
r is unused here now.
Should we introduce an n_jobs parameter to NMF for the prange?
Gut feeling: +1
+1 for n_jobs.
@macg0406 It's finally been decided not to expose an n_jobs parameter for that. However, the number of threads for the prange needs to be given by a helper.
Are you still willing to work on this? If you don't have time I can take over.
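For context, a sketch of what that could look like at the call site. I'm assuming the helper meant is scikit-learn's internal _openmp_effective_n_threads (the one the superseding work relies on), so treat the names and the extended signature as assumptions:

# Hypothetical call site: the thread count comes from a helper rather than
# from an n_jobs parameter on NMF. _openmp_effective_n_threads respects
# OMP_NUM_THREADS and falls back to 1 when OpenMP is unavailable.
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads

n_threads = _openmp_effective_n_threads()
# the Cython routine would forward this to prange(..., num_threads=n_threads)
violation = _update_cdnmf_fast(W, HHt, XHt, permutation, n_threads)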
A lot of things have changed on the parallelization side and this PR is superseded by #16439. I'm closing it.
Parallelized the function _update_cdnmf_fast by using prange/OpenMP, so that the NMF fitting process is faster. The environment variable OMP_NUM_THREADS can be used to set the maximum number of threads used.