Truly parallel execution of pairwise_kernels and pairwise_distances #29587

Closed
stepan-srsen opened this issue Jul 30, 2024 · 15 comments · Fixed by #29693

stepan-srsen (Contributor) commented Jul 30, 2024

Describe the workflow you want to enable

Both the pairwise_kernels and pairwise_distances functions call the _parallel_pairwise function, which is (contrary to its name) not parallel as it enforces the threading backend. Therefore, these functions are terribly slow, especially for computationally expensive user-defined metrics. I understand that the reasons for the threading backend are possibly large memory demands and data communication overhead, but I suggest a different approach. Also, the documentation for these functions talks about parallel execution and processes, which is currently simply not true.

Describe your proposed solution

The memory and data communication issues can be reduced by a smarter distribution of the input data to individual processes. Right now, only Y is sliced in the _parallel_pairwise function, which is suboptimal for parallel processing. Both X and Y should be sliced to lower the cost of multiprocessing. For example, for 100x100 X and Y distributed to 100 processes, we have to copy 100+1 rows to every process when slicing only Y, but only 10+10 rows when slicing both X and Y into a 10x10 grid of blocks (see the sketch below). As a result, multiprocessing can be allowed. Also, joblib does automatic memmapping in some cases.
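
A back-of-the-envelope sketch of this arithmetic (hypothetical 100x100 shapes, illustration only):

# Rows shipped to each worker when slicing only Y versus slicing both
# X and Y into a grid of blocks.
n_x, n_y, n_workers = 100, 100, 100

# Y-only slicing: every worker receives all of X plus one slice of Y.
rows_y_only = n_x + n_y // n_workers        # 100 + 1 = 101 rows per worker

# X+Y slicing into a 10x10 grid: each worker receives one slice of each.
grid = int(n_workers ** 0.5)                # 10
rows_xy = n_x // grid + n_y // grid         # 10 + 10 = 20 rows per worker

print(rows_y_only, rows_xy)                 # 101 20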

Alternatively, at least the documentation for pairwise_kernels and pairwise_distances should be corrected.

Describe alternatives you've considered, if relevant

No response

Additional context

No response

stepan-srsen added the Needs Triage and New Feature labels Jul 30, 2024
ogrisel (Member) commented Jul 30, 2024

which is (contrary to its name) not parallel as it enforces the threading backend.
... Therefore, these functions are terribly slow, especially for computationally expensive user-defined metrics.

Some CPU-intensive operations in numpy or pandas do release the GIL, so the threading backend might be a good choice.
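
For example, in a toy illustration like the following (arbitrary sizes), even the threading backend can scale, because the heavy work happens inside numpy, which releases the GIL:

import numpy as np
from joblib import Parallel, delayed

X = np.random.rand(2000, 2000)
# np.dot releases the GIL during the BLAS call, so the two threads can
# actually run in parallel despite the GIL.
results = Parallel(n_jobs=2, backend="threading")(
    delayed(np.dot)(X, X) for _ in range(2)
)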

Furthermore, CPython 3.13 will come with an optional build flag that should allow running without any GIL (also known as the "free threading" mode of Python). It's not yet fully operational (it can cause segfaults in some cases), but @ngoldbaum, @lesteve and others are making progress at identifying and fixing the remaining problems. See #28978 for a tracking issue on the scikit-learn side.

It could be interesting to see if that can help with your specific workload.

More generally, before deciding to embark on a specific refactoring strategy for _parallel_pairwise, it would be interesting to see working code for typical uses of pairwise_kernels and pairwise_distances where you identified performance problems, and maybe report some quick benchmark results on the impact of varying the n_jobs parameter.
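
A minimal sketch of that kind of quick check (synthetic random data, shapes chosen arbitrarily):

import time

import numpy as np
from sklearn.metrics import pairwise_distances

X = np.random.rand(2000, 100)
for n_jobs in (1, 2, 4, 8):
    tic = time.perf_counter()
    pairwise_distances(X, metric="euclidean", n_jobs=n_jobs)
    print(f"n_jobs={n_jobs}: {time.perf_counter() - tic:.3f}s")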

Also, joblib does automatic memmapping in some cases.

I am worried about relying on this too much. In retrospect, I feel that it's a bit of brittle black magic that is complex to debug and maintain, and whose performance implications (both in terms of memory usage and computational overhead) are hard to reason about.
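
When it gets in the way, the heuristic can at least be disabled explicitly; a minimal sketch (max_nbytes=None turns off joblib's automatic memmapping of large array arguments):

import numpy as np
from joblib import Parallel, delayed

X = np.random.rand(10_000, 100)

# max_nbytes=None disables the automatic memmapping of large inputs, so
# the behavior is explicit instead of depending on an array-size heuristic.
results = Parallel(n_jobs=2, max_nbytes=None)(
    delayed(np.sum)(X[i::2]) for i in range(2)
)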

More generally, process-based parallelism can lead to hard-to-debug situations (see the recently opened #29579 for instance), so I would rather rely more on thread-based parallelism than process-based parallelism for scikit-learn in the future.

ogrisel added the Needs Benchmarks and Needs Reproducible Code labels and removed the Needs Triage label Jul 30, 2024
stepan-srsen (Contributor Author) commented:

Hi @ogrisel,
thanks for your reply.

Some CPU-intensive operations in numpy or pandas do release the GIL, so the threading backend might be a good choice.

True, that's why I wrote that it is especially severe for user-defined CPU-bound metric functions. I will try to put together a benchmark for a common sklearn kernel to see the effect there.

Furthermore, CPython 3.13 will come with an optional build flag that should allow running without any GIL (also known as the "free threading" mode of Python). It's not yet fully operational (it can cause segfaults in some cases), but @ngoldbaum, @lesteve and others are making progress at identifying and fixing the remaining problems. See #28978 for a tracking issue on the scikit-learn side.

It could be interesting to see if that can help with your specific workload.

That would be super nice to have. The question is how long it might take to make it practically usable in sklearn.

More generally, before deciding to embark on a specific refactoring strategy for _parallel_pairwise, it would be interesting to see working code for typical uses of pairwise_kernels and pairwise_distances where you identified performance problems, and maybe report some quick benchmark results on the impact of varying the n_jobs parameter.

I observed a huge impact for my user-defined CPU-intensive metric function, but for most people, predefined sklearn metrics are more relevant, so I will benchmark those first.

Also, joblib does automatic memmapping in some cases.

I am worried about relying on this too much. In retrospect, I feel that it's a bit of brittle black magic that is complex to debug and maintain, and whose performance implications (both in terms of memory usage and computational overhead) are hard to reason about.

Totally agree, automatic memmapping was just a side note.

More generally, process-based parallelism can lead to hard-to-debug situations (see the recently opened #29579 for instance), so I would rather rely more on thread-based parallelism than process-based parallelism for scikit-learn in the future.

Sure, threading is better if it provides the same speed-up; it is just very often not the case with Python. That's why many sklearn functions use the default joblib backend based on multiprocessing. If free threading proves to work well and becomes available in sklearn, it might help a lot. But as of now, multiprocessing is the only way to get full parallelism.

I will soon provide some benchmark data.

lesteve (Member) commented Jul 31, 2024

Furthermore, CPython 3.13 will come with an optional build flag that should allow running without any GIL (also known as the "free threading" mode of Python)

That would be super nice to have. The question is how long it might take to make it practically usable in sklearn.

This is already usable in the sense that we have a CI for CPython 3.13 free-threaded and all the scikit-learn tests pass. If you have time to try CPython 3.13 free-threaded on your particular use case, feedback would be super useful and welcome 🙏.

As of today, you have a few different ways to install CPython 3.13 free-threaded locally (through conda, your Linux distribution package manager, python.org downloads, etc.); see the py-free-threading.github.io doc. numpy, scipy, scikit-learn, and other projects have development wheels for CPython 3.13 free-threaded; see this doc.

As with this kind of ongoing work, there may be some caveats along the road. For example, my understanding is that there is currently a hit on single-threaded performance in CPython 3.13 free-threaded, but this will be improved in the future. I am not sure about the details; this may be quickly mentioned at one point in Anthony Shaw's PyCon US 2024 talk video.
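
To double-check which build you are actually running, a quick sanity check (these introspection hooks exist in CPython 3.13+):

import sys
import sysconfig

# Py_GIL_DISABLED is 1 on free-threaded builds of CPython 3.13+.
print("free-threaded build:", bool(sysconfig.get_config_var("Py_GIL_DISABLED")))

# On 3.13+, sys._is_gil_enabled() reports whether the GIL is active at
# runtime (it can be re-enabled, e.g. by an incompatible extension module).
if hasattr(sys, "_is_gil_enabled"):
    print("GIL enabled at runtime:", sys._is_gil_enabled())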

stepan-srsen (Contributor Author) commented Jul 31, 2024

@ogrisel
I benchmarked the _parallel_pairwise function using the vectorized sklearn rbf_kernel metric and also a non-vectorized user-defined rbf metric (as a simple example). The first case is therefore IO-bound and the second CPU-bound. I tested the original _parallel_pairwise function as well as modified versions that allow multiprocessing, with Y-only slicing and with X+Y slicing (not sure if my slicing is optimal though). Here are the results for the IO-bound case (linear and log scale):

Multithreading works in general better there. However, for the CPU-bound case, multiprocessing is way better:

In fact, threading has no effect in the CPU-bound case. Also, slicing both X and Y decreases the communication and memory overhead, and maybe there is an even better way of slicing than the one I used. You can find the code below. The problem I see is that the user does not have the option to choose between threading and multiprocessing. As the joblib manual states, one should not hardcode the backend but use the prefer option instead; in that case, the user can change the backend from outside the function. The problem is that the current version of _parallel_pairwise writes to the output array in place, which is the fastest way for threading, but also prevents using multiprocessing.

@lesteve
Good to know. I hope I will find some time to test it. I am really thrilled about it.

import numpy as np
from sklearn.metrics.pairwise import rbf_kernel, _parallel_pairwise, _return_float_dtype, euclidean_distances, _pairwise_callable
from sklearn.utils import gen_even_slices
from sklearn.utils.validation import _num_samples
import matplotlib.pyplot as plt
from joblib import parallel_config, effective_n_jobs, delayed, Parallel
from functools import partial
import time
import math

# Non-vectorized (pure-Python) RBF kernel between two samples: a simple
# CPU-bound stand-in metric for the benchmark.
def rbf_kernel2(x, y, gamma=None):
    if gamma is None:
        gamma = 1.0 / len(x)
    k = 0.0
    for i in range(len(x)):
        k += (x[i]-y[i])**2
    k *= -gamma
    k = math.exp(k)
    return k

def _parallel_pairwise2(X, Y, func, n_jobs, **kwds):
    """Break the pairwise matrix in n_jobs even slices
    and compute them in parallel."""

    if Y is None:
        Y = X
    X, Y, dtype = _return_float_dtype(X, Y)

    if effective_n_jobs(n_jobs) == 1:
        return func(X, Y, **kwds)
    
    ret = np.empty((X.shape[0], Y.shape[0]), dtype=dtype, order="F")
    slices = list(gen_even_slices(_num_samples(Y), effective_n_jobs(n_jobs)))
    out = Parallel(n_jobs=n_jobs)(
        delayed(func)(X, Y[s], **kwds)
        for s in slices
    )
    for i, s in enumerate(slices):
        ret[:, s] = out[i]

    return ret

def _parallel_pairwise3(X, Y, func, n_jobs, **kwds):
    """Break the pairwise matrix in n_jobs even slices
    and compute them in parallel."""

    if Y is None:
        Y = X
    X, Y, dtype = _return_float_dtype(X, Y)

    if effective_n_jobs(n_jobs) == 1:
        return func(X, Y, **kwds)
    
    ret = np.empty((X.shape[0], Y.shape[0]), dtype=dtype, order="F")
    eff_jobs = effective_n_jobs(n_jobs)
    n_jobs1 = int(np.sqrt(eff_jobs))
    n_jobs2 = int(eff_jobs//n_jobs1)
    # print(n_jobs1, n_jobs2)
    slices1 = list(gen_even_slices(_num_samples(X), n_jobs1))
    slices2 = list(gen_even_slices(_num_samples(Y), n_jobs2))
    out = Parallel(n_jobs=n_jobs)(
        delayed(func)(X[s1], Y[s2], **kwds)
        for s1 in slices1 for s2 in slices2
    )
    i = 0
    for s1 in slices1:
        for s2 in slices2:
            ret[s1, s2] = out[i]
            i += 1

    return ret

# NOTE: this benchmark uses the IPython %timeit magic, so it must be run in
# IPython/Jupyter rather than with the plain python interpreter.
for n_samples, metric, title in [(5000, rbf_kernel, 'IO-bound'), (100, partial(_pairwise_callable, metric=rbf_kernel2), 'CPU-bound')]:
    # n_samples = 1000
    n_jobs = np.arange(1,11)
    # metric = rbf_kernel
    # metric = partial(_pairwise_callable, metric=rbf_kernel2)
    X, Y = np.random.rand(n_samples, 500), np.random.rand(n_samples, 500)
    results = [[] for i in range(5)]
    for n in n_jobs:
        with parallel_config(backend='threading'):
            result = %timeit -o -n 1 -r 10 _parallel_pairwise(X, Y, metric, n_jobs=n)
            results[0].append(result)
            result = %timeit -o -n 1 -r 10 _parallel_pairwise2(X, Y, metric, n_jobs=n)
            results[1].append(result)
            result = %timeit -o -n 1 -r 10 _parallel_pairwise3(X, Y, metric, n_jobs=n)
            results[2].append(result)
        with parallel_config(backend='loky'):
            result = %timeit -o -n 1 -r 10 _parallel_pairwise2(X, Y, metric, n_jobs=n)
            results[3].append(result)
            result = %timeit -o -n 1 -r 10 _parallel_pairwise3(X, Y, metric, n_jobs=n)
            results[4].append(result)
        print()
    times = []
    for i in range(len(results)):
        times.append([np.mean(result.timings) for result in results[i]])
    for scale in ['linear', 'log']:
        plt.figure()
        plt.title('_parallel_pairwise, '+title+', '+str(scale)+' scale, '+str(n_samples)+' samples')
        plt.xlim(np.min(n_jobs),np.max(n_jobs))
        if scale=='linear':
            plt.ylim(0,np.max(times))
        plt.xlabel('# of CPUs')
        plt.ylabel('times [s]')
        plt.xscale(scale)
        plt.yscale(scale)
        plt.plot(n_jobs, times[0], label='threading: original')
        plt.plot(n_jobs, times[1], label='threading: mod. Y slicing')
        plt.plot(n_jobs, times[2], label='threading: mod. X+Y slicing')
        plt.plot(n_jobs, times[3], label='multiprocessing: mod. Y slicing')
        plt.plot(n_jobs, times[4], label='multiprocessing: mod. X+Y slicing')
        plt.legend()
        plt.show()

EDIT:
output of $ python -c 'import sklearn; sklearn.show_versions()'

System:
    python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
executable: /user/srsen/bin/anaconda3/envs/cheml/bin/python
   machine: Linux-4.18.0-348.23.1.el8_5.x86_64-x86_64-with-glibc2.28

Python dependencies:
      sklearn: 1.4.2
          pip: 24.0
   setuptools: 69.5.1
        numpy: 1.26.4
        scipy: 1.13.1
       Cython: 3.0.10
       pandas: None
   matplotlib: 3.8.4
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: mkl
    num_threads: 20
         prefix: libmkl_rt
       filepath: /user/srsen/bin/anaconda3/envs/cheml/lib/libmkl_rt.so.2
        version: 2023.1-Product
threading_layer: intel

       user_api: openmp
   internal_api: openmp
    num_threads: 20
         prefix: libiomp
       filepath: /user/srsen/bin/anaconda3/envs/cheml/lib/libiomp5.so
        version: None

       user_api: openmp
   internal_api: openmp
    num_threads: 20
         prefix: libgomp
       filepath: /user/srsen/bin/anaconda3/envs/cheml/lib/libgomp.so.1.0.0
        version: None

ogrisel (Member) commented Aug 1, 2024

Thank you very much for your analysis. Looking forward to seeing the results with free-threaded Python if someone has time to set it up.

lesteve (Member) commented Aug 1, 2024

@stepan-srsen could you add the output of python -c 'import sklearn; sklearn.show_versions()' to your earlier comment (i.e. #29587 (comment))? This could be useful for future reference 🙏.

I am trying to run your snippet and, full disclosure, for now I am getting weird results on vanilla Python, so I must be doing something wrong; I need to have a closer look ...

stepan-srsen (Contributor Author) commented Aug 2, 2024

@lesteve @ogrisel
Hi,
I found some time to test the free-threaded version and here are the results. For the IO-bound case without setting the OPENBLAS_NUM_THREADS environment variable (and therefore also using the internal OpenBLAS threading):

Everything, even multiprocessing, was tested with the GIL turned off. None of the approaches actually accelerates the calculation in this case/setup when allowing parallelism inside the _parallel_pairwise function. However, if I set OPENBLAS_NUM_THREADS=1, I get the following results:

These results correspond to the results from the regular Python setup above. Interestingly, BLAS threading can speed up single-core execution, but its interplay with the parallelism used within the _parallel_pairwise function can actually lead to worse results.

For the CPU-bound tasks, the picture is clear: free-threading performs the same as multiprocessing or even slightly better (for numbers of CPUs for which the slicing is suboptimal), and the best times are basically the same as in my previous setup. I set OPENBLAS_NUM_THREADS=1, but it does not play a role in this use case.

Dependency versions in the new setup:

System:
    python: 3.13.0b4 | packaged by Anaconda, Inc. | experimental free-threading build (main, Jul 24 2024, 20:51:55) [GCC 11.2.0]
executable: /user/srsen/bin/anaconda3/envs/nogil/bin/python
   machine: Linux-4.18.0-348.23.1.el8_5.x86_64-x86_64-with-glibc2.28

Python dependencies:
      sklearn: 1.6.dev0
          pip: 24.1.2
   setuptools: None
        numpy: 2.1.0.dev0
        scipy: 1.15.0.dev0
       Cython: None
       pandas: None
   matplotlib: 3.10.0.dev503+gf4f40ba46f
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libscipy_openblas
       filepath: /user/srsen/bin/anaconda3/envs/nogil/lib/python3.13t/site-packages/numpy.libs/libscipy_openblas64_-99b71e71.so
        version: 0.3.27
threading_layer: pthreads
   architecture: Haswell

       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libscipy_openblas
       filepath: /user/srsen/bin/anaconda3/envs/nogil/lib/python3.13t/site-packages/scipy.libs/libscipy_openblas-c128ec02.so
        version: 0.3.27.dev
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 20
         prefix: libgomp
       filepath: /user/srsen/bin/anaconda3/envs/nogil/lib/python3.13t/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None

lesteve (Member) commented Aug 2, 2024

@stepan-srsen just to make sure: did you set MKL_NUM_THREADS=1 (for the vanilla Python run, since you are using MKL in that environment) or OPENBLAS_NUM_THREADS=1 (for free-threaded Python, since you are using OpenBLAS there)?

I think not setting MKL_NUM_THREADS=1 was one of the reasons I was seeing surprising results originally.

I was also wondering how the outer-loop parallelism performance (i.e. what you are showing) with the BLAS library in single-threaded mode compares to no parallelism in the outer loop while letting the BLAS library use all the cores, but I did not have time to look at this ...

ogrisel (Member) commented Aug 2, 2024

For the CPU-bound tasks, the picture is clear: free-threading performs the same as multiprocessing or even slightly better (for numbers of CPUs for which the slicing is suboptimal), and the best times are basically the same as in my previous setup.

That's already great news! Thanks very much for taking the time to run this.

For the "IO-bound" case it would be worth investigating more about the interaction with different implementations of multithreaded BLAS used internally by numpy and scikit-learn (via scipy).

ogrisel (Member) commented Aug 2, 2024

Thanks for updating the benchmark results with OPENBLAS_NUM_THREADS=1. So all in all, free-threading seems to work as intended but might require some manual protection against oversubscription with BLAS (for instance using threadpoolctl.threadpool_limits), as we typically do for BLAS calls nested under Cython/prange/OpenMP in other parts of the scikit-learn code base.
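
For instance, a minimal sketch of capping the BLAS threads around the parallel call:

import numpy as np
from sklearn.metrics import pairwise_distances
from threadpoolctl import threadpool_limits

X = np.random.rand(2000, 100)

# Cap the BLAS thread pools to one thread while the outer joblib threads
# run, so the two layers of parallelism do not oversubscribe the cores.
with threadpool_limits(limits=1, user_api="blas"):
    D = pairwise_distances(X, n_jobs=4)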

Given those results, I think we can keep scikit-learn's pairwise_kernels unchanged and instead allocate more effort to making scikit-learn and joblib work well with free-threading CPython and the threading backend by protecting against oversubscription.

stepan-srsen (Contributor Author) commented Aug 2, 2024

@lesteve
Good point, I didn't take care of that in the second setup. I have updated the results now.

@ogrisel @lesteve
The free threading works nicely, but it is still in the development stage, and I assume that plenty of libraries won't be compatible with it for quite some time. As @ogrisel pointed out, oversubscription should also be taken care of. I still think that threading should not be hardcoded; as the joblib manual suggests, the prefer keyword should be used instead. Or at least, the documentation for pairwise_kernels and pairwise_distances should be updated to clearly state that it is just threading. I was surprised when I turned on what I thought was multiprocessing and it didn't help with my CPU-intensive calculation. The standard approach is that one can set the joblib backend from outside the function using with parallel_config(backend=...): as I did in the examples above.

ogrisel (Member) commented Aug 5, 2024

I still think that threading should not be hardcoded; as the joblib manual suggests, the prefer keyword should be used instead.

+1 for using the prefer keyword. Feel free to open a PR.
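
In a nutshell (prefer is only a soft hint that an outer parallel_config can override, whereas a hardcoded backend cannot be changed from the outside):

from joblib import Parallel, delayed, parallel_config

# prefer="threads" is a hint: the outer context wins ...
with parallel_config(backend="loky"):
    Parallel(n_jobs=2, prefer="threads")(
        delayed(pow)(i, 2) for i in range(4)
    )  # runs with the loky (process) backend

# ... whereas an explicit backend ignores the surrounding context:
with parallel_config(backend="loky"):
    Parallel(n_jobs=2, backend="threading")(
        delayed(pow)(i, 2) for i in range(4)
    )  # still runs with threads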

stepan-srsen (Contributor Author) commented Aug 8, 2024

Hi,

@ogrisel @lesteve
the problem is that the function currently writes to the array in place, which prohibits multiprocessing. We could use require='sharedmem' instead of prefer='threads', but that doesn't really help. So we would have to write to the array outside of the parallel environment, as I do in the modified functions above, to allow multiprocessing. By the way, I found out that NOT enforcing the Fortran order (i.e. not using order="F") makes the code slightly faster in this case, so the difference between writing in place and not writing in place is not as big for threading as it looks from the figures above. Maybe the best way would be to recognize whether multiprocessing has been requested from outside the function and write in place only for the threading backend. What do you think? Also, what kind of slicing should I use for the PR? In my own code, I slice both X and Y, as that is best for multiprocessing, and I use prefer='threads' for predefined metrics and prefer='processes' for user-defined metrics.
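
For reference, a minimal sketch of the in-place pattern, which only works with shared memory (and hence forces the threading backend):

import numpy as np
from joblib import Parallel, delayed
from sklearn.utils import gen_even_slices

def fill_block(out, X, Y, s):
    # each task writes its own slice of the result in place
    out[:, s] = X @ Y[s].T

X, Y = np.random.rand(100, 20), np.random.rand(100, 20)
ret = np.empty((X.shape[0], Y.shape[0]))
Parallel(n_jobs=2, require="sharedmem")(
    delayed(fill_block)(ret, X, Y, s)
    for s in gen_even_slices(Y.shape[0], 2)
)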

Also, another thought: will Python 3.13 support free-threading by default, or will it need a special build? If it is the second case, then sklearn shouldn't just assume that the user compiles Python with an experimental feature.

BTW you can remove the Needs Benchmarks and Needs Reproducible Code tags.

ogrisel removed the Needs Benchmarks and Needs Reproducible Code labels Aug 8, 2024
ogrisel (Member) commented Aug 8, 2024

Also, another thought: will Python 3.13 support free-threading by default, or will it need a special build? If it is the second case, then sklearn shouldn't just assume that the user compiles Python with an experimental feature.

The official Python 3.13 binaries provided by python.org will not be free-threaded by default, but some conda channels (and maybe other distributions) will ship both free-threaded and GIL-based Python versions. scikit-learn will probably ship wheels for both ABIs, and users will be free to install what they want. If conda-forge is ready in time, we will also ship free-threaded scikit-learn packages on conda-forge.

I don't think that the full Python ecosystem will be 100% free-threading ready at the time of the CPython 3.13 release (in October), but I hope that most of the core scientific packages will be ready.

My point is that I would rather invest dev effort in making free-threading work as well as possible, as soon as possible (in 2024 or 2025), than invest effort in working around GIL and process-based parallelism limitations in the medium term.

I am fine with short-term pragmatic improvements for process-based parallelism, though, if they do not add too much complexity to our code bases (scikit-learn, joblib, loky).

stepan-srsen (Contributor Author) commented Aug 19, 2024

@ogrisel
I totally get your point. I am afraid some complexity would have to be added to keep the performance of the threading approach for IO-bound tasks. Therefore, I will just create a PR to improve the documentation. I have added the final version of my function (with improved slicing and array memory order) below, in case somebody is interested. As mentioned above, I use threading by default for predefined kernels and metrics and multiprocessing for user-defined functions. By the way, SharedMemory from multiprocessing might also help with the multiprocessing overhead.

Results using the function below and NOT using free threading:

def _parallel_pairwise4(X, Y, func, n_jobs, **kwds):
    """Break the pairwise matrix in n_jobs even slices
    and compute them in parallel."""

    if Y is None:
        Y = X

    if effective_n_jobs(n_jobs) == 1:
        return func(X, Y, **kwds)
    
    ret = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype)
    eff_jobs = effective_n_jobs(n_jobs)
    # Choose a roughly square n_jobs1 x n_jobs2 grid of blocks whose aspect
    # ratio matches the X/Y sizes, so that each worker receives as few rows
    # as possible.
    a = math.sqrt(eff_jobs / (len(X) * len(Y)))
    n_jobs1 = int(a * min(len(X), len(Y)))
    if n_jobs1 == 0:
        n_jobs1 = 1
    else:
        # round n_jobs1 down to the nearest divisor of eff_jobs so that
        # n_jobs1 * n_jobs2 == eff_jobs exactly
        for i in range(n_jobs1, 0, -1):
            if eff_jobs % i == 0:
                n_jobs1 = i
                break
    n_jobs2 = int(eff_jobs // n_jobs1)
    # give the larger number of slices to the larger input
    if len(Y) < len(X):
        n_jobs1, n_jobs2 = n_jobs2, n_jobs1
    slices1 = list(gen_even_slices(len(X), n_jobs1))
    slices2 = list(gen_even_slices(len(Y), n_jobs2))
    out = Parallel(n_jobs=n_jobs)(
        delayed(func)(X[s1], Y[s2], **kwds)
        for s1 in slices1 for s2 in slices2
    )
    i = 0
    for s1 in slices1:
        for s2 in slices2:
            ret[s1, s2] = out[i]
            i += 1

    return ret
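
A hypothetical usage sketch (assuming the imports from the benchmark snippet earlier in this thread):

X, Y = np.random.rand(1000, 50), np.random.rand(1000, 50)

# threading for vectorized built-in metrics ...
with parallel_config(backend="threading"):
    K = _parallel_pairwise4(X, Y, rbf_kernel, n_jobs=4)

# ... and processes for CPU-bound user-defined Python callables
with parallel_config(backend="loky"):
    K2 = _parallel_pairwise4(
        X, Y, partial(_pairwise_callable, metric=rbf_kernel2), n_jobs=4
    )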
