Skip to content

Commit a06d78a

Browse files
authored
Remove token mode (#279)
* remove token mode * fix test * fix test * I said fix * raise if a token argument is passed * update actions * 🐛
1 parent 6a7786d commit a06d78a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+820
-3649
lines changed

.github/workflows/mpi-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
- os: ubuntu-latest
3131
python-version: "3.10"
3232
mpi: openmpi
33-
jax-version: "0.4.35"
33+
jax-version: "0.5.1"
3434

3535
env:
3636
MPICH_INTERFACE_HOSTNAME: localhost

.github/workflows/pre-commit.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ jobs:
1212
runs-on: ubuntu-latest
1313

1414
steps:
15-
- uses: actions/checkout@v2
16-
- uses: actions/setup-python@v2
17-
- uses: pre-commit/action@v2.0.0
15+
- uses: actions/checkout@v4
16+
- uses: actions/setup-python@v5
17+
- uses: pre-commit/action@v3.0.1

docs/sharp-bits.rst

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,6 @@
33

44
Read ahead for some pitfalls, counter-intuitive behavior, and sharp edges that we had to introduce in order to make this work.
55

6-
.. _tokens:
7-
8-
Token management
9-
----------------
10-
11-
The compiler behind JAX, XLA, is not aware of the fact that MPI function calls such as :func:`~mpi4jax.send` or :func:`~mpi4jax.recv` must be performed in a specific order (in jargon, that they have *side-effects*). It is therefore free to reorder those calls. Reordering of MPI calls usually leads to deadlocks, e.g. when both processes end up receiving before sending (instead of send-receive, receive-send).
12-
13-
*Tokens* are JAX's way to ensure that XLA does not re-order statements with side effects by injecting a fake data dependency between them.
14-
15-
This means that you *have* to use proper token management to prevent reordering from occurring. Every communication primitive in ``mpi4jax`` returns a token as the last return object, which you have to pass to subsequent primitives within the same JIT block, like this:
16-
17-
.. code:: python
18-
19-
# DO NOT DO THIS
20-
mpi4jax.send(arr, comm=comm)
21-
new_arr, _ = mpi4jax.recv(arr, comm=comm)
22-
23-
# INSTEAD, DO THIS
24-
token = mpi4jax.send(arr, comm=comm)
25-
new_arr, token = mpi4jax.recv(arr, comm=comm, token=token)
26-
27-
Those functions will then be executed in the same order as the sequence of tokens, from first to last.
28-
29-
Automatic token management
30-
~~~~~~~~~~~~~~~~~~~~~~~~~~
31-
32-
An alternative to manual token management is to use the primitives from :mod:`mpi4jax.notoken`, which automatically manage tokens for you. For example, the following code is equivalent to the previous example:
33-
34-
.. code:: python
35-
36-
import mpi4jax.notoken
37-
38-
mpi4jax.notoken.send(arr, comm=comm)
39-
new_arr = mpi4jax.notoken.recv(arr, comm=comm)
40-
41-
42-
This will likely become the default behavior in the future.
43-
446
No in-place operations in JAX
457
-----------------------------
468

docs/usage.rst

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ The result is an array full of the value 6, because each process adds its rank t
4242
Basic example: sending and receiving
4343
------------------------------------
4444

45-
``mpi4jax`` can of course also send and receive data without performing an operation on it. For this, you can use :func:`~mpi4jax.send` and :func:`~mpi4jax.recv`:
45+
``mpi4jax`` can send and receive data without performing an operation on it. For this, you can use :func:`~mpi4jax.send` and :func:`~mpi4jax.recv`:
4646

4747
.. _example_2:
4848

@@ -64,12 +64,12 @@ Basic example: sending and receiving
6464
# note: this could also use mpi4jax.sendrecv
6565
if rank == 0:
6666
# send, then receive
67-
token = mpi4jax.send(arr, dest=1, comm=comm)
68-
other_arr, token = mpi4jax.recv(arr, source=1, comm=comm, token=token)
67+
mpi4jax.send(arr, dest=1, comm=comm)
68+
other_arr = mpi4jax.recv(arr, source=1, comm=comm)
6969
else:
7070
# receive, then send
71-
other_arr, token = mpi4jax.recv(arr, source=0, comm=comm)
72-
token = mpi4jax.send(arr, dest=0, comm=comm, token=token)
71+
other_arr = mpi4jax.recv(arr, source=0, comm=comm)
72+
mpi4jax.send(arr, dest=0, comm=comm)
7373
7474
return other_arr
7575
@@ -91,18 +91,3 @@ Executing this shows that each process has received the data from the other proc
9191
[1. 1. 1.]]
9292
9393
For operations like this, the correct order of the :func:`~mpi4jax.send` / :func:`~mpi4jax.recv` calls is critical to prevent the program from deadlocking (e.g. when both processes wait for data at the same time).
94-
95-
In ``mpi4jax``, we enforce order of execution through *tokens*. In :ref:`the example code <example_2>`, you can see this behavior e.g. in the following lines:
96-
97-
.. code:: python
98-
99-
token = mpi4jax.send(arr, dest=1, comm=comm)
100-
other_arr, token = mpi4jax.recv(arr, source=1, comm=comm, token=token)
101-
102-
The first call to :func:`~mpi4jax.send` returns a token, which we then pass to :func:`~mpi4jax.recv`. :func:`~mpi4jax.recv` *also* returns a new token that we could pass to subsequent communication primitives.
103-
104-
Because of the nature of JAX, **using tokens to enforce order is not optional.** If you do not use correct token management, you will experience deadlocks and crashes.
105-
106-
.. seealso::
107-
108-
For more information on tokens, see :ref:`tokens`.

examples/shallow_water.py

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,15 @@ def get_initial_conditions():
162162
u0_local = u0_global[local_slice]
163163
v0_local = v0_global[local_slice]
164164

165-
token = jax.lax.create_token()
166-
h0_local, token = enforce_boundaries(h0_local, "h", token)
167-
u0_local, token = enforce_boundaries(u0_local, "u", token)
168-
v0_local, token = enforce_boundaries(v0_local, "v", token)
165+
h0_local = enforce_boundaries(h0_local, "h")
166+
u0_local = enforce_boundaries(u0_local, "u")
167+
v0_local = enforce_boundaries(v0_local, "v")
169168

170169
return h0_local, u0_local, v0_local
171170

172171

173172
@partial(jax.jit, static_argnums=(1,))
174-
def enforce_boundaries(arr, grid, token=None):
173+
def enforce_boundaries(arr, grid):
175174
"""Handle boundary exchange between processors.
176175
177176
This is where mpi4jax comes in!
@@ -222,9 +221,6 @@ def enforce_boundaries(arr, grid, token=None):
222221
if proc_idx[1] == nproc_x - 1:
223222
proc_neighbors["east"] = (proc_idx[0], 0)
224223

225-
if token is None:
226-
token = jax.lax.create_token()
227-
228224
for send_dir, recv_dir in zip(send_order, recv_order):
229225
send_proc = proc_neighbors[send_dir]
230226
recv_proc = proc_neighbors[recv_dir]
@@ -245,20 +241,17 @@ def enforce_boundaries(arr, grid, token=None):
245241
send_arr = arr[send_idx]
246242

247243
if send_proc is None:
248-
recv_arr, token = mpi4jax.recv(
249-
recv_arr, source=recv_proc, comm=mpi_comm, token=token
250-
)
244+
recv_arr = mpi4jax.recv(recv_arr, source=recv_proc, comm=mpi_comm)
251245
arr = arr.at[recv_idx].set(recv_arr)
252246
elif recv_proc is None:
253-
token = mpi4jax.send(send_arr, dest=send_proc, comm=mpi_comm, token=token)
247+
mpi4jax.send(send_arr, dest=send_proc, comm=mpi_comm)
254248
else:
255-
recv_arr, token = mpi4jax.sendrecv(
249+
recv_arr = mpi4jax.sendrecv(
256250
send_arr,
257251
recv_arr,
258252
source=recv_proc,
259253
dest=send_proc,
260254
comm=mpi_comm,
261-
token=token,
262255
)
263256
arr = arr.at[recv_idx].set(recv_arr)
264257

@@ -268,7 +261,7 @@ def enforce_boundaries(arr, grid, token=None):
268261
if grid == "v" and proc_idx[0] == nproc_y - 1:
269262
arr = arr.at[-2, :].set(0.0)
270263

271-
return arr, token
264+
return arr
272265

273266

274267
ModelState = namedtuple("ModelState", "h, u, v, dh, du, dv")
@@ -280,20 +273,18 @@ def shallow_water_step(state, is_first_step):
280273
281274
Returns modified model state.
282275
"""
283-
token = jax.lax.create_token()
284-
285276
h, u, v, dh, du, dv = state
286277

287278
hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
288-
hc, token = enforce_boundaries(hc, "h", token)
279+
hc = enforce_boundaries(hc, "h")
289280

290281
fe = jnp.empty_like(u)
291282
fn = jnp.empty_like(u)
292283

293284
fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
294285
fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
295-
fe, token = enforce_boundaries(fe, "u", token)
296-
fn, token = enforce_boundaries(fn, "v", token)
286+
fe = enforce_boundaries(fe, "u")
287+
fn = enforce_boundaries(fn, "v")
297288

298289
dh_new = dh.at[1:-1, 1:-1].set(
299290
-(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
@@ -312,7 +303,7 @@ def shallow_water_step(state, is_first_step):
312303
q = q.at[1:-1, 1:-1].mul(
313304
1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:]))
314305
)
315-
q, token = enforce_boundaries(q, "h", token)
306+
q = enforce_boundaries(q, "h")
316307

317308
du_new = du.at[1:-1, 1:-1].set(
318309
-GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
@@ -337,7 +328,7 @@ def shallow_water_step(state, is_first_step):
337328
+ 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
338329
)
339330
)
340-
ke, token = enforce_boundaries(ke, "h", token)
331+
ke = enforce_boundaries(ke, "h")
341332

342333
du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
343334
dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)
@@ -369,9 +360,9 @@ def shallow_water_step(state, is_first_step):
369360
)
370361
)
371362

372-
h, token = enforce_boundaries(h, "h", token)
373-
u, token = enforce_boundaries(u, "u", token)
374-
v, token = enforce_boundaries(v, "v", token)
363+
h = enforce_boundaries(h, "h")
364+
u = enforce_boundaries(u, "u")
365+
v = enforce_boundaries(v, "v")
375366

376367
if LATERAL_VISCOSITY > 0:
377368
# lateral friction
@@ -381,8 +372,8 @@ def shallow_water_step(state, is_first_step):
381372
fn = fn.at[1:-1, 1:-1].set(
382373
LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
383374
)
384-
fe, token = enforce_boundaries(fe, "u", token)
385-
fn, token = enforce_boundaries(fn, "v", token)
375+
fe = enforce_boundaries(fe, "u")
376+
fn = enforce_boundaries(fn, "v")
386377

387378
u = u.at[1:-1, 1:-1].add(
388379
dt
@@ -398,8 +389,8 @@ def shallow_water_step(state, is_first_step):
398389
fn = fn.at[1:-1, 1:-1].set(
399390
LATERAL_VISCOSITY * (v[2:, 1:-1] - u[1:-1, 1:-1]) / dy
400391
)
401-
fe, token = enforce_boundaries(fe, "u", token)
402-
fn, token = enforce_boundaries(fn, "v", token)
392+
fe = enforce_boundaries(fe, "u")
393+
fn = enforce_boundaries(fn, "v")
403394

404395
v = v.at[1:-1, 1:-1].add(
405396
dt

0 commit comments

Comments
 (0)