DOC Update "Parallelism, resource management, and configuration" section #24997

Merged Nov 25, 2022 (8 commits)
150 changes: 91 additions & 59 deletions doc/computing/parallelism.rst
@@ -10,38 +10,46 @@ Parallelism, resource management, and configuration
Parallelism
-----------

-Some scikit-learn estimators and utilities can parallelize costly operations
-using multiple CPU cores, thanks to the following components:
+Some scikit-learn estimators and utilities parallelize costly operations
+using multiple CPU cores.

-- via the `joblib <https://joblib.readthedocs.io/en/latest/>`_ library. In
-  this case the number of threads or processes can be controlled with the
-  ``n_jobs`` parameter.
-- via OpenMP, used in C or Cython code.
+Depending on the type of estimator and sometimes the values of the
+constructor parameters, this is done either:

-In addition, some of the numpy routines that are used internally by
-scikit-learn may also be parallelized if numpy is installed with specific
-numerical libraries such as MKL, OpenBLAS, or BLIS.
+- with higher-level parallelism via `joblib <https://joblib.readthedocs.io/en/latest/>`_.
+- with lower-level parallelism via OpenMP, used in C or Cython code.
+- with lower-level parallelism via BLAS, used by NumPy and SciPy for generic
+  operations on arrays.

-We describe these 3 scenarios in the following subsections.
+The `n_jobs` parameter of estimators always controls the amount of parallelism
+managed by joblib (processes or threads, depending on the joblib backend).
+The thread-level parallelism managed by OpenMP in scikit-learn's own Cython
+code, or by the BLAS & LAPACK libraries used by NumPy and SciPy operations
+within scikit-learn, is always controlled by environment variables or
+`threadpoolctl`, as explained below.
+Note that some estimators can leverage all three kinds of parallelism at
+different points of their training and prediction methods.

-Joblib-based parallelism
-........................
+We describe these 3 types of parallelism in the following subsections in more detail.

+Higher-level parallelism with joblib
+....................................

When the underlying implementation uses joblib, the number of workers
(threads or processes) that are spawned in parallel can be controlled via the
``n_jobs`` parameter.
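
For instance, a minimal sketch of fitting a forest with two joblib workers;
the estimator choice and the synthetic dataset are illustrative, not part of
this changeset::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import RandomForestClassifier

    X, y = make_classification(n_samples=1000, random_state=0)

    # The 100 trees are fit in parallel across 2 joblib workers;
    # n_jobs=-1 would use as many workers as there are CPU cores.
    clf = RandomForestClassifier(n_estimators=100, n_jobs=2).fit(X, y)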

.. note::

-    Where (and how) parallelization happens in the estimators is currently
-    poorly documented. Please help us by improving our docs and tackle `issue
-    14228 <https://github.com/scikit-learn/scikit-learn/issues/14228>`_!
+    Where (and how) parallelization happens in the estimators that use joblib
+    via the `n_jobs` parameter is currently poorly documented.
+    Please help us by improving our docs and tackling `issue 14228
+    <https://github.com/scikit-learn/scikit-learn/issues/14228>`_!

Joblib is able to support both multi-processing and multi-threading. Whether
joblib chooses to spawn a thread or a process depends on the **backend**
that it's using.

-Scikit-learn generally relies on the ``loky`` backend, which is joblib's
+scikit-learn generally relies on the ``loky`` backend, which is joblib's
default backend. Loky is a multi-processing backend. When doing
multi-processing, in order to avoid duplicating the memory in each process
(which isn't reasonable with big datasets), joblib will create a `memmap
@@ -70,40 +78,57 @@ that increasing the number of workers is always a good thing. In some cases
it can be highly detrimental to performance to run multiple copies of some
estimators or functions in parallel (see oversubscription below).

-OpenMP-based parallelism
-........................
+Lower-level parallelism with OpenMP
+...................................

OpenMP is used to parallelize code written in Cython or C, relying on
-multi-threading exclusively. By default (and unless joblib is trying to
-avoid oversubscription), the implementation will use as many threads as
-possible.
+multi-threading exclusively. By default, the implementations using OpenMP
+will use as many threads as possible, i.e. as many threads as logical cores.

-You can control the exact number of threads that are used via the
-``OMP_NUM_THREADS`` environment variable:
+You can control the exact number of threads that are used either:

-.. prompt:: bash $
-
-    OMP_NUM_THREADS=4 python my_script.py
+- via the ``OMP_NUM_THREADS`` environment variable, for instance when
+  running a python script:

+  .. prompt:: bash $

+      OMP_NUM_THREADS=4 python my_script.py

+- or via `threadpoolctl` as explained by `this piece of documentation
+  <https://github.com/joblib/threadpoolctl/#setting-the-maximum-size-of-thread-pools>`_.
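
As a minimal sketch of the `threadpoolctl` route (the estimator and the
synthetic dataset are illustrative choices)::

    from threadpoolctl import threadpool_limits
    from sklearn.datasets import make_classification
    from sklearn.ensemble import HistGradientBoostingClassifier

    X, y = make_classification(random_state=0)

    # Cap the OpenMP thread pools at 2 threads for the duration of the block.
    with threadpool_limits(limits=2, user_api="openmp"):
        HistGradientBoostingClassifier().fit(X, y)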

-Parallel Numpy routines from numerical libraries
-................................................
+Parallel NumPy and SciPy routines from numerical libraries
+..........................................................

-Scikit-learn relies heavily on NumPy and SciPy, which internally call
-multi-threaded linear algebra routines implemented in libraries such as MKL,
-OpenBLAS or BLIS.
+scikit-learn relies heavily on NumPy and SciPy, which internally call
+multi-threaded linear algebra routines (BLAS & LAPACK) implemented in libraries
+such as MKL, OpenBLAS or BLIS.

-The number of threads used by the OpenBLAS, MKL or BLIS libraries can be set
-via the ``MKL_NUM_THREADS``, ``OPENBLAS_NUM_THREADS``, and
-``BLIS_NUM_THREADS`` environment variables.
+You can control the exact number of threads used by BLAS for each library
+using environment variables, namely:

+- ``MKL_NUM_THREADS`` sets the number of threads MKL uses,
+- ``OPENBLAS_NUM_THREADS`` sets the number of threads OpenBLAS uses,
+- ``BLIS_NUM_THREADS`` sets the number of threads BLIS uses.
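
For instance, assuming your NumPy build is linked against OpenBLAS
(`my_script.py` is a placeholder name):

.. prompt:: bash $

    OPENBLAS_NUM_THREADS=2 python my_script.py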

-Please note that scikit-learn has no direct control over these
-implementations. Scikit-learn solely relies on Numpy and Scipy.
+Note that BLAS & LAPACK implementations can also be impacted by
+`OMP_NUM_THREADS`. To check whether this is the case in your environment,
+you can inspect how the number of threads effectively used by those libraries
+is affected when running the following command in a bash or zsh terminal
+for different values of `OMP_NUM_THREADS`:

+.. prompt:: bash $

+    OMP_NUM_THREADS=2 python -m threadpoolctl -i numpy scipy

.. note::
At the time of writing (2019), NumPy and SciPy packages distributed on
pypi.org (used by ``pip``) and on the conda-forge channel are linked
with OpenBLAS, while conda packages shipped on the "defaults" channel
from anaconda.org are linked by default with MKL.
At the time of writing (2022), NumPy and SciPy packages which are
distributed on pypi.org (i.e. the ones installed via ``pip install``)
and on the conda-forge channel (i.e. the ones installed via
``conda install --channel conda-forge``) are linked with OpenBLAS, while
NumPy and SciPy packages packages shipped on the ``defaults`` conda
channel from Anaconda.org (i.e. the ones installed via ``conda install``)
are linked by default with MKL.
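
To check which BLAS implementation your own NumPy and SciPy installations are
linked against, you can reuse the `threadpoolctl` command shown above:

.. prompt:: bash $

    python -m threadpoolctl -i numpy scipy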


Oversubscription: spawning too many threads
@@ -120,8 +145,8 @@ with ``n_jobs=8`` over a
OpenMP). Each instance of
:class:`~sklearn.ensemble.HistGradientBoostingClassifier` will spawn 8 threads
(since you have 8 CPUs). That's a total of ``8 * 8 = 64`` threads, which
-leads to oversubscription of physical CPU resources and to scheduling
-overhead.
+leads to oversubscription of threads for physical CPU resources and thus
+to scheduling overhead.
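
A minimal sketch of that scenario (the parameter grid and the synthetic
dataset are illustrative)::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import HistGradientBoostingClassifier
    from sklearn.model_selection import GridSearchCV

    X, y = make_classification(n_samples=1000, random_state=0)

    # 8 joblib workers, each fitting an OpenMP-parallelized estimator: on an
    # 8-core machine this could request up to 8 * 8 = 64 threads at once.
    search = GridSearchCV(
        HistGradientBoostingClassifier(),
        {"max_iter": [50, 100, 200]},
        n_jobs=8,
    ).fit(X, y)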

Oversubscription can arise in the exact same fashion with parallelized
routines from MKL, OpenBLAS or BLIS that are nested in joblib calls.
@@ -146,38 +171,34 @@ Note that:
only use ``<LIB>_NUM_THREADS``. Joblib exposes a context manager for
finer control over the number of threads in its workers (see joblib docs
linked below).
-- Joblib is currently unable to avoid oversubscription in a
-  multi-threading context. It can only do so with the ``loky`` backend
-  (which spawns processes).
+- When joblib is configured to use the ``threading`` backend, there is no
+  mechanism to avoid oversubscription when calling into parallel native
+  libraries in the joblib-managed threads.
+- All scikit-learn estimators that explicitly rely on OpenMP in their Cython
+  code always use `threadpoolctl` internally to automatically adapt the number
+  of threads used by OpenMP and potentially nested BLAS calls so as to avoid
+  oversubscription.

You will find additional details about joblib mitigation of oversubscription
in the `joblib documentation
<https://joblib.readthedocs.io/en/latest/parallel.html#avoiding-over-subscription-of-cpu-resources>`_.

You will find additional details about parallelism in numerical Python libraries
in `this document from Thomas J. Fan <https://thomasjpfan.github.io/parallelism-python-libraries-design/>`_.

Configuration switches
-----------------------

-Python runtime
-..............
+Python API
+..........

-:func:`sklearn.set_config` controls the following behaviors:
-
-`assume_finite`
-~~~~~~~~~~~~~~~
-
-Used to skip validation, which enables faster computations but may lead to
-segmentation faults if the data contains NaNs.
-
-`working_memory`
-~~~~~~~~~~~~~~~~
-
-The optimal size of temporary arrays used by some algorithms.
+:func:`sklearn.set_config` and :func:`sklearn.config_context` can be used to change
+parameters of the configuration which control aspects of parallelism.
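
For instance, a minimal sketch with :func:`sklearn.config_context`; the
`working_memory` value and the chunked pairwise computation are illustrative::

    import numpy as np

    import sklearn
    from sklearn.metrics import pairwise_distances_chunked

    X = np.random.RandomState(0).rand(1000, 10)

    # Inside this block, chunked reductions size their temporary distance
    # matrices to stay within roughly 64 MiB of working memory.
    with sklearn.config_context(working_memory=64):
        chunks = list(pairwise_distances_chunked(X))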

.. _environment_variable:

Environment variables
-......................
+.....................

These environment variables should be set before importing scikit-learn.

@@ -277,3 +298,14 @@ float64 data.
When this environment variable is set to a non-zero value, the `Cython`
directive `boundscheck` is set to `True`. This is useful for finding
segfaults.

+`SKLEARN_PAIRWISE_DIST_CHUNK_SIZE`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+This sets the chunk size to be used by the underlying `PairwiseDistancesReductions`
+implementations. The default value is `256`, which has been shown to be adequate on
+most machines.

+Users looking for the best performance might want to tune this variable using
+powers of 2 so as to get the best parallelism behavior for their hardware,
+especially with respect to their cache sizes.
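
For instance, a hypothetical run with a larger chunk size (`my_script.py` is
a placeholder name):

.. prompt:: bash $

    SKLEARN_PAIRWISE_DIST_CHUNK_SIZE=512 python my_script.py
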
18 changes: 12 additions & 6 deletions sklearn/metrics/pairwise.py
Expand Up @@ -688,9 +688,12 @@ def pairwise_distances_argmin_min(
        values = values.flatten()
        indices = indices.flatten()
    else:
-        # TODO: once BaseDistanceReductionDispatcher supports distance metrics
-        # for boolean datasets, we won't need to fallback to
-        # pairwise_distances_chunked anymore.
+        # Joblib-based backend, which is used when user-defined callables
+        # are passed for metric.

+        # This fallback won't be needed once PairwiseDistancesReductions
+        # support:
+        #   - DistanceMetrics which work on supposedly binary data
+        #   - CSR-dense and dense-CSR case if 'euclidean' in metric.

        # Turn off check for finiteness because this is costly and because arrays
        # have already been validated.
@@ -800,9 +803,12 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean", metric_kwargs
        )
        indices = indices.flatten()
    else:
-        # TODO: once BaseDistanceReductionDispatcher supports distance metrics
-        # for boolean datasets, we won't need to fallback to
-        # pairwise_distances_chunked anymore.
+        # Joblib-based backend, which is used when user-defined callables
+        # are passed for metric.

+        # This fallback won't be needed once PairwiseDistancesReductions
+        # support:
+        #   - DistanceMetrics which work on supposedly binary data
+        #   - CSR-dense and dense-CSR case if 'euclidean' in metric.

        # Turn off check for finiteness because this is costly and because arrays
        # have already been validated.
18 changes: 13 additions & 5 deletions sklearn/neighbors/_base.py
@@ -839,9 +839,13 @@ class from an array representing our data set and ask who's
            )

        elif self._fit_method == "brute":
-            # TODO: should no longer be needed once ArgKmin
-            # is extended to accept sparse and/or float32 inputs.
+            # Joblib-based backend, which is used when user-defined callables
+            # are passed for metric.

+            # This fallback won't be needed once PairwiseDistancesReductions
+            # support:
+            #   - DistanceMetrics which work on supposedly binary data
+            #   - CSR-dense and dense-CSR case if 'euclidean' in metric.
            reduce_func = partial(
                self._kneighbors_reduce_func,
                n_neighbors=n_neighbors,
@@ -1173,9 +1177,13 @@ class from an array representing our data set and ask who's
            )

        elif self._fit_method == "brute":
-            # TODO: should no longer be needed once we have Cython-optimized
-            # implementation for radius queries, with support for sparse and/or
-            # float32 inputs.
+            # Joblib-based backend, which is used when user-defined callables
+            # are passed for metric.

+            # This fallback won't be needed once PairwiseDistancesReductions
+            # support:
+            #   - DistanceMetrics which work on supposedly binary data
+            #   - CSR-dense and dense-CSR case if 'euclidean' in metric.

            # for efficiency, use squared euclidean distances
            if self.effective_metric_ == "euclidean":