Description
I get a segmentation fault with some MPI primitives when using CUDA-enabled MPI. The issue seems to appear when XLA is not initialized, as the error disappears if memory is allocated on the GPU before mpi4jax is imported.
Run command used (with GPUs on two separate nodes):
MPI4JAX_USE_CUDA_MPI=1 mpiexec -npernode 1 python run.py
Contents of run.py to reproduce the error:
import mpi4jax
from mpi4py import MPI
comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
rank = comm.Get_rank()
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)  # segfaults here with CUDA MPI
print(rank, root_rank)
Error message:
--------------------------------------------------------------------------
Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec noticed that process rank 0 with PID 0 on node 0 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------
Workarounds
1: Using MPI4JAX_USE_CUDA_MPI=0.
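For example, the same run command as above with CUDA MPI disabled:
MPI4JAX_USE_CUDA_MPI=0 mpiexec -npernode 1 python run.py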
2: Importing jax and creating a DeviceArray fixes the problem, but it has to happen before import mpi4jax. As an example, inserting this at the beginning of run.py works:
import jax.numpy as jnp
jnp.array(3.)
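For clarity, this is roughly what run.py looks like with workaround 2 applied (the two jax lines placed above import mpi4jax, everything else unchanged):
import jax.numpy as jnp
jnp.array(3.)  # allocating a DeviceArray first, which presumably initializes XLA / the GPU backend (see description above)

import mpi4jax
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
rank = comm.Get_rank()

root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)  # no longer segfaults
print(rank, root_rank)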
3: Some primitives (only mpi4jax.allreduce was tested) work just fine out of the box. In the following piece of code, the allreduce and its print run without crashing; the segfault only occurs once the bcast is reached:
rank_sum, _ = mpi4jax.allreduce(rank, op=MPI.SUM, comm=comm)
print(rank, rank_sum)
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)
Versions
Python 3.8.6
OpenMPI 4.0.5-gcccuda-2020b
CUDA 11.1.1-GCC-10.2.0
mpi4py 3.1.1
mpi4jax 0.3.2
jax 0.2.21
jaxlib 0.1.71[cuda111]