
Segmentation fault on GPU to GPU communication #106

@halvarsu

Description


I get a segmentation fault with some MPI primitives when using CUDA-aware MPI. The issue seems to appear when XLA has not yet been initialized: the error disappears if memory is allocated on the GPU before mpi4jax is imported.

Run command (with GPUs on two separate nodes):

MPI4JAX_USE_CUDA_MPI=1 mpiexec -npernode 1 python run.py

Contents of run.py to reproduce the error:

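import mpi4jax  # imported before any JAX array has been created
from mpi4py import MPI

comm = MPI.COMM_WORLD
comm_size = comm.Get_size()
rank = comm.Get_rank()

# bcast returns (result, token); with MPI4JAX_USE_CUDA_MPI=1 this call segfaults
root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)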

Error message:

--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec noticed that process rank 0 with PID 0 on node 0 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------

Workarounds

1: Setting MPI4JAX_USE_CUDA_MPI=0 (i.e. not using CUDA-aware MPI).

2: Importing jax and creating a DeviceArray fixes the problem, but this must happen before mpi4jax is imported. For example, inserting the following at the beginning of run.py works:

import jax.numpy as jnp
jnp.array(3.)
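For reference, here is a sketch of run.py with workaround 2 applied. It assumes, as observed above, that allocating any small array is enough to initialize XLA's GPU backend. The helper names `warm_up_xla` and `main` are hypothetical, and the `(result, token)` return value of `bcast` matches the mpi4jax 0.3.2 API listed under Versions. The mpi4jax and mpi4py imports are deferred into `main()` so that they are guaranteed to run after the warm-up.

```python
import jax.numpy as jnp


def warm_up_xla():
    """Allocate a small array so XLA initializes its default backend.

    Hypothetical helper mirroring workaround 2: it must run before
    mpi4jax is imported. block_until_ready() waits until the array
    has actually been materialized on the device.
    """
    x = jnp.array(3.0)
    x.block_until_ready()
    return x


def main():
    warm_up_xla()

    # Deferred imports: mpi4jax is only imported after XLA is initialized.
    import mpi4jax
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # mpi4jax 0.3.x collectives return (result, token)
    root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
    print(rank, root_rank)


if __name__ == "__main__":
    main()
```

It is launched the same way as the original reproducer, e.g. `MPI4JAX_USE_CUDA_MPI=1 mpiexec -npernode 1 python run.py`.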

3: Some primitives (I have only tested mpi4jax.allreduce) work fine out of the box. The following code, which calls allreduce before the bcast, does not crash:

rank_sum, _ = mpi4jax.allreduce(rank, op=MPI.SUM, comm=comm)
print(rank, rank_sum)

root_rank, _ = mpi4jax.bcast(rank, root=0, comm=comm)
print(rank, root_rank)

Versions

Python 3.8.6
OpenMPI 4.0.5-gcccuda-2020b
CUDA 11.1.1.GCC-10.2.0
mpi4py 3.1.1
mpi4jax 0.3.2
jax 0.2.21
jaxlib 0.1.71[cuda111]
