I was pleasantly surprised to see that mpi4jax automatically supports custom `MPI_Reduce` operations created with `MPI.Op.Create`. See for example the following:
```python
import numpy as np
from mpi4py import MPI
import mpi4jax
import jax
import jax.numpy as jnp
from functools import partial

rank = MPI.COMM_WORLD.rank

# create the arrays to reduce (converted to a JAX array below)
src = (np.arange(8) + rank*8).reshape(4,2)
src[0] = rank
src = jnp.array(src)
dst = np.zeros_like(src)

MPI.COMM_WORLD.barrier()
print("starting")
MPI.COMM_WORLD.barrier()

# custom reduction: MPI hands over raw buffers, the result must be written into ymem
def myadd(xmem, ymem, dt):
    x = np.frombuffer(xmem, dtype=src.dtype)
    y = np.frombuffer(ymem, dtype=src.dtype)
    z = x + y
    print("Rank %d reducing %s (%s) and %s (%s), yielding %s" % (rank, x, type(x), y, type(y), z))
    y[:] = z

op = MPI.Op.Create(myadd, commute=True)

#MPI.COMM_WORLD.Reduce(src, dst, op)
# keep the result: mpi4jax returns it instead of filling dst
# (some mpi4jax versions return a (result, token) tuple)
res = jax.jit(partial(mpi4jax.reduce, op=op, root=0))(src)
if MPI.COMM_WORLD.rank == 0:
    print("ANSWER: %s" % (res,))

MPI.COMM_WORLD.barrier()
print("-------------------------------")
MPI.COMM_WORLD.barrier()

res = jax.jit(partial(mpi4jax.allreduce, op=op))(src)
if MPI.COMM_WORLD.rank == 0:
    print("ANSWER: %s" % (res,))
```
However, this only works because mpi4py calls back into the Python runtime each time MPI applies the operation, which slows down execution.
Ideally, I'd like to use Numba/CFFI to define custom operations that don't call back into Python.