Skip to content

[MRG] DOC More details about parallelism (joblib, openMP, MKL...) #15116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 15, 2019
25 changes: 11 additions & 14 deletions doc/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -299,23 +299,20 @@ documentation <https://docs.python.org/3/library/multiprocessing.html#contexts-a

.. _faq_mkl_threading:

Why does my job use more cores than specified with n_jobs under OSX or Linux?
-----------------------------------------------------------------------------
Why does my job use more cores than specified with n_jobs?
----------------------------------------------------------

This happens when vectorized numpy operations are handled by libraries such
as MKL or OpenBLAS.
This is because ``n_jobs`` only controls the number of jobs for
routines that are parallelized with ``joblib``, but parallel code can come
from other sources:

While scikit-learn adheres to the limit set by ``n_jobs``,
numpy operations vectorized using MKL (or OpenBLAS) will make use of multiple
threads within each scikit-learn job (thread or process).
- some routines may be parallelized with OpenMP (for code written in C or
Cython).
- scikit-learn relies a lot on numpy, which in turn may rely on numerical
libraries like MKL, OpenBLAS or BLIS which can provide parallel
implementations.

The number of threads used by the BLAS library can be set via an environment
variable. For example, to set the maximum number of threads to some integer
value ``N``, the following environment variables should be set:

* For MKL: ``export MKL_NUM_THREADS=N``

* For OpenBLAS: ``export OPENBLAS_NUM_THREADS=N``
For more details, please refer to our :ref:`Parallelism notes <parallelism>`.


Why is there no support for deep or reinforcement learning / Will there be support for deep or reinforcement learning in scikit-learn?
Expand Down
56 changes: 17 additions & 39 deletions doc/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1508,45 +1508,23 @@ functions or non-estimator constructors.
early.

``n_jobs``
This is used to specify how many concurrent processes/threads should be
used for parallelized routines. Scikit-learn uses one processor for
its processing by default, although it also makes use of NumPy, which
may be configured to use a threaded numerical processor library (like
MKL; see :ref:`FAQ <faq_mkl_threading>`).

``n_jobs`` is an int, specifying the maximum number of concurrently
running jobs. If set to -1, all CPUs are used. If 1 is given, no
joblib level parallelism is used at all, which is useful for
debugging. Even with ``n_jobs = 1``, parallelism may occur due to
numerical processing libraries (see :ref:`FAQ <faq_mkl_threading>`).
For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. Thus for
``n_jobs = -2``, all CPUs but one are used.

``n_jobs=None`` means *unset*; it will generally be interpreted as
``n_jobs=1``, unless the current :class:`joblib.Parallel` backend
context specifies otherwise.

The use of ``n_jobs``-based parallelism in estimators varies:

* Most often parallelism happens in :term:`fitting <fit>`, but
sometimes parallelism happens in prediction (e.g. in random forests).
* Some parallelism uses a multi-threading backend by default, some
a multi-processing backend. It is possible to override the default
backend by using :func:`sklearn.utils.parallel_backend`.
* Whether parallel processing is helpful at improving runtime depends
on many factors, and it's usually a good idea to experiment rather
than assuming that increasing the number of jobs is always a good
thing. *It can be highly detrimental to performance to run multiple
copies of some estimators or functions in parallel.*

Nested uses of ``n_jobs``-based parallelism with the same backend will
result in an exception.
So ``GridSearchCV(OneVsRestClassifier(SVC(), n_jobs=2), n_jobs=2)``
won't work.

When ``n_jobs`` is not 1, the estimator being parallelized must be
picklable. This means, for instance, that lambdas cannot be used
as estimator parameters.
This parameter is used to specify how many concurrent processes or
threads should be used for routines that are parallelized with
:term:`joblib`.

``n_jobs`` is an integer, specifying the maximum number of concurrently
running workers. If 1 is given, no joblib parallelism is used at all,
which is useful for debugging. If set to -1, all CPUs are used. For
``n_jobs`` below -1, (n_cpus + 1 + n_jobs) are used. For example with
``n_jobs=-2``, all CPUs but one are used.

``n_jobs`` is ``None`` by default, which means *unset*; it will
generally be interpreted as ``n_jobs=1``, unless the current
:class:`joblib.Parallel` backend context specifies otherwise.

For more details on the use of ``joblib`` and its interactions with
scikit-learn, please refer to our :ref:`parallelism notes
<parallelism>`.

``pos_label``
Value with which positive labels must be encoded in binary
Expand Down
150 changes: 141 additions & 9 deletions doc/modules/computing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -504,20 +504,152 @@ Links
- `Scipy sparse matrix formats documentation <https://docs.scipy.org/doc/scipy/reference/sparse.html>`_

Parallelism, resource management, and configuration
=====================================================
===================================================

.. _parallelism:

Parallel and distributed computing
-----------------------------------
Parallelism
-----------

Scikit-learn uses the `joblib <https://joblib.readthedocs.io/en/latest/>`__
library to enable parallel computing inside its estimators. See the
joblib documentation for the switches to control parallel computing.
Some scikit-learn estimators and utilities can parallelize costly operations
using multiple CPU cores, thanks to the following components:

- via the `joblib <https://joblib.readthedocs.io/en/latest/>`_ library. In
this case the number of threads or processes can be controlled with the
``n_jobs`` parameter.
- via OpenMP, used in C or Cython code.

In addition, some of the numpy routines that are used internally by
scikit-learn may also be parallelized if numpy is installed with specific
numerical libraries such as MKL, OpenBLAS, or BLIS.

We describe these 3 scenarios in the following subsections.

Joblib-based parallelism
........................

When the underlying implementation uses joblib, the number of workers
(threads or processes) that are spawned in parallel can be controled via the
``n_jobs`` parameter.

.. note::

Where (and how) parallelization happens in the estimators is currently
poorly documented. Please help us by improving our docs and tackle `issue
14228 <https://github.com/scikit-learn/scikit-learn/issues/14228>`_!

Joblib is able to support both multi-processing and multi-threading. Whether
joblib chooses to spawn a thread or a process depends on the **backend**
that it's using.

Scikit-learn generally relies on the ``loky`` backend, which is joblib's
default backend. Loky is a multi-processing backend. When doing
multi-processing, in order to avoid duplicating the memory in each process
(which isn't reasonable with big datasets), joblib will create a `memmap
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.memmap.html>`_
that all processes can share, when the data is bigger than 1MB.

In some specific cases (when the code that is run in parallel releases the
GIL), scikit-learn will indicate to ``joblib`` that a multi-threading
backend is preferable.

As a user, you may control the backend that joblib will use (regardless of
what scikit-learn recommends) by using a context manager::

from joblib import parallel_backend

with parallel_backend('threading', n_jobs=2):
# Your scikit-learn code here

Please refer to the `joblib's docs
<https://joblib.readthedocs.io/en/latest/parallel.html#thread-based-parallelism-vs-process-based-parallelism>`_
for more details.

In practice, whether parallelism is helpful at improving runtime depends on
many factors. It is usually a good idea to experiment rather than assuming
that increasing the number of workers is always a good thing. In some cases
it can be highly detrimental to performance to run multiple copies of some
estimators or functions in parallel (see oversubscription below).

OpenMP-based parallelism
........................

OpenMP is used to parallelize code written in Cython or C, relying on
multi-threading exclusively. By default (and unless joblib is trying to
avoid oversubscription), the implementation will use as many threads as
possible.

You can control the exact number of threads that are used via the
``OMP_NUM_THREADS`` environment variable::

OMP_NUM_THREADS=4 python my_script.py

Parallel Numpy routines from numerical libraries
................................................

Scikit-learn relies heavily on NumPy and SciPy, which internally call
multi-threaded linear algebra routines implemented in libraries such as MKL,
OpenBLAS or BLIS.

The number of threads used by the OpenBLAS, MKL or BLIS libraries can be set
via the ``MKL_NUM_THREADS``, ``OPENBLAS_NUM_THREADS``, and
``BLIS_NUM_THREADS`` environment variables.

Please note that scikit-learn has no direct control over these
implementations. Scikit-learn solely relies on Numpy and Scipy.

.. note::
At the time of writing (2019), NumPy and SciPy packages distributed on
pypi.org (used by ``pip``) and on the conda-forge channel are linked
with OpenBLAS, while conda packages shipped on the "defaults" channel
from anaconda.org are linked by default with MKL.


Oversubscription: spawning too many threads
...........................................

It is generally recommended to avoid using significantly more processes or
threads than the number of CPUs on a machine. Over-subscription happens when
a program is running too many threads at the same time.

Suppose you have a machine with 8 CPUs. Consider a case where you're running
a :class:`~GridSearchCV` (parallelized with joblib) with ``n_jobs=8`` over
a :class:`~HistGradientBoostingClassifier` (parallelized with OpenMP). Each
instance of :class:`~HistGradientBoostingClassifier` will spawn 8 threads
(since you have 8 CPUs). That's a total of ``8 * 8 = 64`` threads, which
leads to oversubscription of physical CPU resources and to scheduling
overhead.

Oversubscription can arise in the exact same fashion with parallelized
routines from MKL, OpenBLAS or BLIS that are nested in joblib calls.

Starting from ``joblib >= 0.14``, when the ``loky`` backend is used (which
is the default), joblib will tell its child **processes** to limit the
number of threads they can use, so as to avoid oversubscription. In practice
the heuristic that joblib uses is to tell the processes to use ``max_threads
= n_cpus // n_jobs``, via their corresponding environment variable. Back to
our example from above, since the joblib backend of :class:`~GridSearchCV`
is ``loky``, each process will only be able to use 1 thread instead of 8,
thus mitigating the oversubscription issue.

Note that:

- Manually setting one of the environment variables (``OMP_NUM_THREADS``,
``MKL_NUM_THREADS``, ``OPENBLAS_NUM_THREADS``, or ``BLIS_NUM_THREADS``)
will take precedence over what joblib tries to do. The total number of
threads will be ``n_jobs * <LIB>_NUM_THREADS``. Note that setting this
limit will also impact your computations in the main process, which will
only use ``<LIB>_NUM_THREADS``. Joblib exposes a context manager for
finer control over the number of threads in its workers (see joblib docs
linked below).
- Joblib is currently unable to avoid oversubscription in a
multi-threading context. It can only do so with the ``loky`` backend
(which spawns processes).

You will find additional details about joblib mitigation of oversubscription
in `joblib documentation
<https://joblib.readthedocs.io/en/latest/parallel.html#avoiding-over-subscription-of-cpu-ressources>`_.

Note that, by default, scikit-learn uses its embedded (vendored) version
of joblib. A configuration switch (documented below) controls this
behavior.

Configuration switches
-----------------------
Expand Down
6 changes: 2 additions & 4 deletions doc/modules/ensemble.rst
Original file line number Diff line number Diff line change
Expand Up @@ -946,10 +946,8 @@ Low-level parallelism

:class:`HistGradientBoostingClassifier` and
:class:`HistGradientBoostingRegressor` have implementations that use OpenMP
for parallelization through Cython. The number of threads that is used can
be changed using the ``OMP_NUM_THREADS`` environment variable. By default,
all available cores are used. Please refer to the OpenMP documentation for
details.
for parallelization through Cython. For more details on how to control the
number of threads, please refer to our :ref:`parallelism` notes.

The following parts are parallelized:

Expand Down