On one of EPFL's GPU clusters (Izar) we need to use OpenMPI compiled with Nvidia's compilers to get direct GPU-to-GPU communication in mpi4jax, and @inailuig found out how to do it.
I'll post the instructions here for the sake of completeness.
# for the env:
module load nvhpc/21.2-mpi gcc/8.4.0 python/3.7.7
export NVHPC="/ssoft/spack/arvine/v1/opt/spack/linux-rhel7-haswell/gcc-4.8.5/nvhpc-21.2-eso63gmamwuohzzccpew522i7ryqjb46/Linux_x86_64/21.2"
export OPENMPI_ROOT="${NVHPC}/comm_libs/openmpi"
export CUDA_ROOT="${NVHPC}/cuda/11.2"
export CUDA_LIBRARY="${CUDA_ROOT}/lib64"
export CUDA_INCLUDE="${CUDA_ROOT}/include"
export LD_LIBRARY_PATH="${OPENMPI_ROOT}/lib:${CUDA_LIBRARY}:${LD_LIBRARY_PATH}"
export LIBRARY_PATH="${OPENMPI_ROOT}/lib:${CUDA_LIBRARY}:${LIBRARY_PATH}"
export CPLUS_INCLUDE_PATH="${OPENMPI_ROOT}/include:${CUDA_INCLUDE}:${CPLUS_INCLUDE_PATH}"
export C_INCLUDE_PATH="${OPENMPI_ROOT}/include:${CUDA_INCLUDE}:${C_INCLUDE_PATH}"
export PATH="${OPENMPI_ROOT}/bin:${CUDA_ROOT}/bin:${PATH}"
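Before installing anything, it may be worth checking that the `mpicc` found on `PATH` is the NVHPC one and that this Open MPI build reports CUDA support. A minimal sketch, assuming `ompi_info` from the loaded module is on `PATH`; the helper name `is_cuda_aware` is made up for illustration:

```shell
# Hypothetical helper: succeeds if the ompi_info output on stdin reports CUDA support.
is_cuda_aware() {
    grep -q 'mpi_built_with_cuda_support:value:true'
}

if command -v ompi_info >/dev/null 2>&1; then
    # Should print a path inside ${OPENMPI_ROOT}/bin if the env above took effect.
    command -v mpicc
    if ompi_info --parsable --all | is_cuda_aware; then
        echo "Open MPI reports CUDA support"
    else
        echo "Open MPI does not report CUDA support"
    fi
else
    echo "ompi_info not found on PATH; load the modules first"
fi
```

If the second message shows up, GPU-to-GPU communication most likely won't work no matter how mpi4jax is built.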
Then we install jax and mpi4py, using a weird flag for mpi4py (don't ask me why, but it won't work otherwise; we think it's because of Nvidia's compiler being weird).
# activate the python venv
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
CFLAGS=-noswitcherror pip install mpi4py
Then to compile mpi4jax we need to edit python's include files
# add this line to venv/include/python3.7m/pymem.h to fix compile error
#include <stddef.h>
(Seems like it's a Python 3.7 problem with the Nvidia clang; we would have to include it in mpi_xla_bridge_gpu.c before Python.h, but I don't know how to do that in Cython.)
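The header edit can also be scripted so it is repeatable. This is only a sketch: the function name `add_stddef_include` is made up, and the path in the usage comment is the one from the step above, so adjust it to your venv.

```shell
# Hypothetical helper: prepend '#include <stddef.h>' to a header, once.
add_stddef_include() {
    header="$1"
    # Do nothing if the include is already there, so re-running is safe.
    if grep -q '^#include <stddef.h>' "$header"; then
        return 0
    fi
    tmp="$(mktemp)" || return 1
    # Build the patched file in a temp file, then write it back in place
    # (writing back with cat keeps the header's original permissions).
    { printf '#include <stddef.h>\n'; cat "$header"; } > "$tmp" \
        && cat "$tmp" > "$header"
    status=$?
    rm -f "$tmp"
    return "$status"
}

# Usage (path taken from the step above):
#   add_stddef_include venv/include/python3.7m/pymem.h
```

Putting the include on the first line, before the header's include guard, should be harmless since `stddef.h` has its own guard.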
Now the weird thing is that if I run
pip install mpi4jax
the installation fails with this error:
In file included from /ssoft/spack/arvine/v1/opt/spack/linux-rhel7-skylake_avx512/gcc-8.4.0/python-3.7.7-drpdlwdbo3lmtkcbckq227ypnzno4ek3/include/python3.7m/Python.h:63,
from mpi4jax/_src/xla_bridge/mpi_xla_bridge.c:26:
/ssoft/spack/arvine/v1/opt/spack/linux-rhel7-skylake_avx512/gcc-8.4.0/python-3.7.7-drpdlwdbo3lmtkcbckq227ypnzno4ek3/include/python3.7m/pymem.h:114:12: error: unknown type name 'wchar_t'
PyAPI_FUNC(wchar_t*) _PyMem_RawWcsdup(const wchar_t *str);
^~~~~~~
This error suggests that when building through pip, the compiler is not using the include files from the current environment (the ones we edited) but those of the original Python installation.
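One way to check this is to ask the interpreter which include directory it reports. A small sketch to run inside the activated venv; the `PYTHON` variable is just a convenience assumed here, not something the build reads:

```shell
# Print the include directory that the interpreter's sysconfig reports
# (the place where Python.h and pymem.h are expected to live).
python_include_dir() {
    "${PYTHON:-python3}" -c "import sysconfig; print(sysconfig.get_paths()['include'])"
}

python_include_dir
```

If this prints the spack path from the error message rather than a path under the venv, that would be consistent with the explanation above.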
To work around this we have to run
git clone https://github.com/mpi4jax/mpi4jax
cd mpi4jax
python setup.py build
python setup.py install
This way pip/python picks up the modified include files.
--
Of course, editing Python's include files is not good. But if users can't install mpi4jax, it's our problem.
After all, mpi4py compiles just fine.
Maybe we can do something to work around Nvidia's bugs?