You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/sharp-bits.rst
-38Lines changed: 0 additions & 38 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -3,44 +3,6 @@
3
3
4
4
Read ahead for some pitfalls, counter-intuitive behavior, and sharp edges that we had to introduce in order to make this work.
5
5
6
-
.. _tokens:
7
-
8
-
Token management
9
-
----------------
10
-
11
-
The compiler behind JAX, XLA, is not aware of the fact that MPI function calls such as :func:`~mpi4jax.send` or :func:`~mpi4jax.recv` must be performed in a specific order (in jargon, that they have *side-effects*). It is therefore free to reorder those calls. Reordering of MPI calls usually leads to deadlocks, e.g. when both processes end up receiving before sending (instead of send-receive, receive-send).
12
-
13
-
*Tokens* are JAX's way to ensure that XLA does not re-order statements with side effects by injecting a fake data dependency between them.
14
-
15
-
This means that you *have* to use proper token management to prevent reordering from occurring. Every communication primitive in ``mpi4jax`` returns a token as the last return object, which you have to pass to subsequent primitives within the same JIT block, like this:
Those functions will then be executed in the same order as the sequence of tokens, from first to last.
28
-
29
-
Automatic token management
30
-
~~~~~~~~~~~~~~~~~~~~~~~~~~
31
-
32
-
An alternative to manual token management is to use the primitives from :mod:`mpi4jax.notoken`, which automatically manage tokens for you. For example, the following code is equivalent to the previous example:
33
-
34
-
.. code:: python
35
-
36
-
import mpi4jax.notoken
37
-
38
-
mpi4jax.notoken.send(arr, comm=comm)
39
-
new_arr = mpi4jax.notoken.recv(arr, comm=comm)
40
-
41
-
42
-
This will likely become the default behavior in the future.
Copy file name to clipboardExpand all lines: docs/usage.rst
+5-20Lines changed: 5 additions & 20 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -42,7 +42,7 @@ The result is an array full of the value 6, because each process adds its rank t
42
42
Basic example: sending and receiving
43
43
------------------------------------
44
44
45
-
``mpi4jax`` can of course also send and receive data without performing an operation on it. For this, you can use :func:`~mpi4jax.send` and :func:`~mpi4jax.recv`:
45
+
``mpi4jax`` can send and receive data without performing an operation on it. For this, you can use :func:`~mpi4jax.send` and :func:`~mpi4jax.recv`:
46
46
47
47
.. _example_2:
48
48
@@ -64,12 +64,12 @@ Basic example: sending and receiving
@@ -91,18 +91,3 @@ Executing this shows that each process has received the data from the other proc
91
91
[1. 1. 1.]]
92
92
93
93
For operations like this, the correct order of the :func:`~mpi4jax.send` / :func:`~mpi4jax.recv` calls is critical to prevent the program from deadlocking (e.g. when both processes wait for data at the same time).
94
-
95
-
In ``mpi4jax``, we enforce order of execution through *tokens*. In :ref:`the example code <example_2>`, you can see this behavior e.g. in the following lines:
The first call to :func:`~mpi4jax.send` returns a token, which we then pass to :func:`~mpi4jax.recv`. :func:`~mpi4jax.recv` *also* returns a new token that we could pass to subsequent communication primitives.
103
-
104
-
Because of the nature of JAX, **using tokens to enforce order is not optional.** If you do not use correct token management, you will experience deadlocks and crashes.
105
-
106
-
.. seealso::
107
-
108
-
For more information on tokens, see :ref:`tokens`.
0 commit comments