Commit 7829951

Compile with pip Nvidia packages... (#236)

If Nvidia pypi packages are installed, use those to compile and link mpi4jax, setting the rpath in the .so file accordingly.

1 parent fea68e4

File tree

4 files changed: +161 −30 lines

.github/workflows/build-gpu-ext.yml

Lines changed: 20 additions & 10 deletions
@@ -23,30 +23,40 @@ jobs:
           cuda: "12.0"
         - os: ubuntu-22.04
           cuda: "12.1"
+        - os: ubuntu-22.04
+          cuda: "pypi"

     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4

       # make sure tags are fetched so we can get a version
       - run: |
           git fetch --prune --unshallow --tags

       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
-          python-version: '3.x'
+          python-version: '3.11'

       - name: Install CUDA
         env:
           cuda: ${{ matrix.cuda }}
         run: |
-          source ./conf/install-cuda-ubuntu.sh
-          if [[ $? -eq 0 ]]; then
-            # Set paths for subsequent steps, using ${CUDA_PATH}
-            echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH"
-            echo "CUDA_PATH=${CUDA_PATH}" >> $GITHUB_ENV
-            echo "${CUDA_PATH}/bin" >> $GITHUB_PATH
-            echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib:${LD_LIBRARY_PATH}" >> $GITHUB_ENV
+          if [[ "${cuda}" == 'pypi' ]]; then
+            echo "Installing jax[cuda] from PyPI"
+            pip install 'nvidia-cublas-cu12>=12.1.3.1'
+            pip install 'nvidia-cuda-cupti-cu12>=12.1.105'
+            pip install 'nvidia-cuda-nvcc-cu12>=12.1.105'
+            pip install 'nvidia-cuda-runtime-cu12>=12.1.105'
+          else
+            source ./conf/install-cuda-ubuntu.sh
+            if [[ $? -eq 0 ]]; then
+              # Set paths for subsequent steps, using ${CUDA_PATH}
+              echo "Adding CUDA to CUDA_PATH, PATH and LD_LIBRARY_PATH"
+              echo "CUDA_PATH=${CUDA_PATH}" >> $GITHUB_ENV
+              echo "${CUDA_PATH}/bin" >> $GITHUB_PATH
+              echo "LD_LIBRARY_PATH=${CUDA_PATH}/lib:${LD_LIBRARY_PATH}" >> $GITHUB_ENV
+            fi
           fi
         shell: bash
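The branching in the "Install CUDA" step above can be sketched in Python (an illustrative helper, not part of the repository; the package pins mirror those in the diff):

```python
def cuda_install_commands(cuda):
    """Sketch of the workflow branch: a matrix value of 'pypi' installs the
    nvidia wheels from PyPI, anything else runs the local CUDA installer."""
    if cuda == "pypi":
        return [
            "pip install 'nvidia-cublas-cu12>=12.1.3.1'",
            "pip install 'nvidia-cuda-cupti-cu12>=12.1.105'",
            "pip install 'nvidia-cuda-nvcc-cu12>=12.1.105'",
            "pip install 'nvidia-cuda-runtime-cu12>=12.1.105'",
        ]
    return ["source ./conf/install-cuda-ubuntu.sh"]

print(cuda_install_commands("pypi")[2])  # pip install 'nvidia-cuda-nvcc-cu12>=12.1.105'
```

Either branch leaves a CUDA toolchain that the later build steps can find, which is why the matrix only needs the single `cuda` key.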

README.rst

Lines changed: 14 additions & 3 deletions
@@ -26,13 +26,24 @@ Installation

     $ pip install mpi4jax                   # Pip
     $ conda install -c conda-forge mpi4jax  # conda

-If you use pip and don't have JAX installed already, you will also need to do:
+Depending on which jax backend you want to use, install ``mpi4jax`` as follows:

 .. code:: bash

-    $ pip install jaxlib
+    # pip install 'jax[cpu]'
+    $ pip install mpi4jax

-(or an equivalent GPU-enabled version, `see the JAX installation instructions <https://github.com/google/jax#installation>`_)
+    # pip install -U 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+    $ pip install cython
+    $ pip install mpi4jax --no-build-isolation
+
+    # pip install -U 'jax[cuda12_local]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+    $ CUDA_ROOT=XXX pip install mpi4jax
+
+    # pip install -U 'jax[cuda12]'
+    # Not yet supported
+
+(for more information on jax GPU distributions, `see the JAX installation instructions <https://github.com/google/jax#installation>`_)

 In case your MPI installation is not detected correctly, `it can help to install mpi4py separately <https://mpi4py.readthedocs.io/en/stable/install.html>`_. When using a pre-installed ``mpi4py``, you *must* use ``--no-build-isolation`` when installing ``mpi4jax``:

docs/installation.rst

Lines changed: 18 additions & 5 deletions
@@ -8,7 +8,7 @@ Start by `installing a suitable version of JAX and jaxlib <https://github.com/google/jax#installation>`_

 .. code:: bash

-    $ pip install jax jaxlib
+    $ pip install 'jax[cpu]'

 .. note::

@@ -67,13 +67,26 @@ Installation with NVIDIA GPU support (CUDA)

 .. note::

-   To use JAX on the GPU, make sure that your ``jaxlib`` is `built with CUDA support <https://github.com/google/jax#installation>`_.
+   There are three ways to install jax with CUDA support:
+
+   - using a pypi-distributed CUDA installation (suggested by the jax developers): ``pip install -U 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html``
+   - using a locally-installed CUDA version, which must be compatible with jax: ``pip install -U 'jax[cuda12_local]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html``
+   - using ``pip install -U 'jax[cuda12]'``, which is not yet supported by ``mpi4jax``
+
+   The procedure to install ``mpi4jax`` differs between the first two situations.

-``mpi4jax`` supports communication of JAX arrays stored in GPU memory.
+To use ``mpi4jax`` with the pypi-distributed nvidia packages, which is the preferred way to install jax, you **must** install ``mpi4jax`` with build-time isolation disabled so that it links against the libraries in the ``nvidia-cuda-nvcc-cu12`` package. To do so, run:

-To build ``mpi4jax``'s GPU extensions, we need to be able to locate the CUDA headers on your system. If they are not detected automatically, you can set the environment variable :envvar:`CUDA_ROOT` when installing ``mpi4jax``::
+.. code:: bash
+
+    # assuming pip install -U 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html has been run
+    $ pip install cython
+    $ pip install mpi4jax --no-build-isolation
+
+Alternatively, to install ``mpi4jax`` against a locally-installed CUDA version, we need to be able to locate the CUDA headers on your system. If they are not detected automatically, you can set the environment variable :envvar:`CUDA_ROOT` when installing ``mpi4jax``::

-    $ CUDA_ROOT=/usr/local/cuda pip install mpi4jax
+    $ CUDA_ROOT=/usr/local/cuda pip install --no-build-isolation mpi4jax

 This is sufficient for most situations. However, ``mpi4jax`` will copy all data from GPU to CPU and back before and after invoking MPI.
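The choice between the two supported install routes can be summarised in a small sketch (a hypothetical helper, not part of mpi4jax; it merely tests whether the pip-distributed ``nvidia`` namespace package is importable):

```python
import importlib.util

def recommended_install_command(cuda_root="/usr/local/cuda"):
    """Illustrative decision rule: if the pip-distributed nvidia CUDA
    packages are present, build with --no-build-isolation against them;
    otherwise point CUDA_ROOT at a local CUDA install."""
    if importlib.util.find_spec("nvidia") is not None:
        return "pip install mpi4jax --no-build-isolation"
    return f"CUDA_ROOT={cuda_root} pip install --no-build-isolation mpi4jax"

print(recommended_install_command())
```

Either way, build isolation is disabled so the build can see the packages already installed in the environment.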

setup.py

Lines changed: 109 additions & 12 deletions
@@ -2,6 +2,10 @@
 import sys
 import shlex

+import importlib.util
+import pathlib
+import fnmatch
+
 from setuptools import setup, find_packages
 from setuptools.extension import Extension
 from setuptools.command.build_ext import build_ext
@@ -91,6 +95,30 @@ def build_extensions(self):
 # Cuda detection


+# partly taken from JAX
+# https://github.com/google/jax/blob/4cca2335220dcc953edd2ac764b2387e53527495/jax/_src/lib/__init__.py#L129
+def get_cuda_paths_from_nvidia_pypi():
+    # Check whether nvidia-cuda-nvcc-cu* is installed. We need the
+    # site-packages directory of this environment; to find it we use
+    # mpi4py, which must be installed.
+    mpi4py_spec = importlib.util.find_spec("mpi4py")
+    depot_path = pathlib.Path(os.path.dirname(mpi4py_spec.origin)).parent
+
+    # If the pip package nvidia-cuda-nvcc-cu* is installed, it should have
+    # both of the things XLA looks for in the cuda path, namely bin/ptxas and
+    # nvvm/libdevice/libdevice.10.bc
+    #
+    # The files are split across two directories, so we return both
+    maybe_cuda_paths = [
+        depot_path / "nvidia" / "cuda_nvcc",
+        depot_path / "nvidia" / "cuda_runtime",
+    ]
+    if all(p.is_dir() for p in maybe_cuda_paths):
+        return [str(p) for p in maybe_cuda_paths]
+    else:
+        return []
+
+
 # Taken from CUPY (MIT License)
 def get_cuda_path():
     nvcc_path = search_on_path(("nvcc", "nvcc.exe"))
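The site-packages trick used by `get_cuda_paths_from_nvidia_pypi` above, resolving an installed package's spec and stepping up to its parent directory, can be sketched in isolation (the function name and the stdlib anchor package are illustrative, not part of mpi4jax):

```python
import importlib.util
import pathlib

def site_packages_of(package_name):
    """Return the directory that contains an installed package, by
    resolving its import spec and stepping up past the package folder."""
    spec = importlib.util.find_spec(package_name)
    if spec is None or spec.origin is None:
        return None
    # .../site-packages/<pkg>/__init__.py -> .../site-packages
    return pathlib.Path(spec.origin).parent.parent

# Any importable package works as an anchor; the stdlib "json" package is
# used here purely for illustration (setup.py anchors on mpi4py instead).
print(site_packages_of("json"))
```

Anchoring on `mpi4py` rather than `nvidia` directly is deliberate: `nvidia-*` wheels form a namespace package without a single `__init__.py` origin, while `mpi4py` is a hard requirement of the build anyway.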
@@ -169,22 +197,86 @@ def get_sycl_info():
 sycl_info = get_sycl_info()


+def find_files(bases, pattern):
+    """Return a list of files matching pattern in the base folders and their subfolders."""
+    if isinstance(bases, (str, pathlib.Path)):
+        bases = [bases]
+
+    result = []
+    for base in bases:
+        for root, dirs, files in os.walk(base):
+            for name in files:
+                if fnmatch.fnmatch(name, pattern):
+                    result.append(os.path.join(root, name))
+    return result
+
+
 def get_cuda_info():
-    cuda_info = {"compile": [], "libdirs": [], "libs": []}
-    cuda_path = get_cuda_path()
-    if not cuda_path:
-        return cuda_info
+    cuda_info = {"compile": [], "libdirs": [], "libs": [], "rpaths": []}

-    incdir = os.path.join(cuda_path, "include")
-    if os.path.isdir(incdir):
-        cuda_info["compile"].append(incdir)
+    # First check whether the nvidia-cuda-nvcc-cu* package is installed.
+    # We ignore CUDA_ROOT in that case, matching jax's behaviour.
+    cuda_paths = get_cuda_paths_from_nvidia_pypi()

-    for libdir in ("lib64", "lib"):
-        full_dir = os.path.join(cuda_path, libdir)
-        if os.path.isdir(full_dir):
-            cuda_info["libdirs"].append(full_dir)
+    # If it is not, try to find the CUDA path by hand
+    if len(cuda_paths) > 0:
+        nvidia_pypi_package = True
+    else:
+        nvidia_pypi_package = False
+        _cuda_path = get_cuda_path()
+        if _cuda_path is None:
+            cuda_paths = []
+        else:
+            cuda_paths = [_cuda_path]

-    cuda_info["libs"].append("cudart")
+    if len(cuda_paths) == 0:
+        return cuda_info
+
+    for cuda_path in cuda_paths:
+        incdir = os.path.join(cuda_path, "include")
+        if os.path.isdir(incdir):
+            cuda_info["compile"].append(incdir)
+
+        for libdir in ("lib64", "lib"):
+            full_dir = os.path.join(cuda_path, libdir)
+            if os.path.isdir(full_dir):
+                cuda_info["libdirs"].append(full_dir)
+
+    # We need to link against libcudart.so:
+    # - With a standard CUDA installation, we simply add a link flag for
+    #   libcudart.so.
+    # - With the nvidia-cuda-nvcc-cu* package, we need to find the exact
+    #   version of libcudart.so to link against, because the package does
+    #   not provide a generic symlink libcudart.so but only libcudart.so.XX.
+    #
+    # Moreover, with nvidia-cuda-nvcc we must add an rpath (runtime search
+    # path), because we do not expect the user to point LD_LIBRARY_PATH at
+    # the nvidia-cuda-nvcc package.
+    if not nvidia_pypi_package:
+        cuda_info["libs"].append("cudart")
+    else:
+        possible_libcudart = find_files(cuda_paths, "libcudart.so*")
+        # find_files returns full paths, so compare basenames
+        libcudart_names = [os.path.basename(p) for p in possible_libcudart]
+
+        if "libcudart.so" in libcudart_names:
+            # If the generic symlink is present, use the standard linker
+            # flag. In theory nvidia-cuda-nvcc-cu12 never ships it, but a
+            # future release might.
+            cuda_info["libs"].append("cudart")
+        elif len(possible_libcudart) > 0:
+            # The standard case for nvidia-cuda-nvcc-cu*: we find a library
+            # libcudart.so.XX. The syntax to link against a specific version
+            # is -l:libcudart.so.XX. We arbitrarily choose the first match
+            # and add the runtime search path accordingly.
+            lib_to_link = possible_libcudart[0]
+            cuda_info["libs"].append(f":{os.path.basename(lib_to_link)}")
+            cuda_info["rpaths"].append(os.path.dirname(lib_to_link))
+        else:
+            # If we cannot find libcudart.so at all, fall back to the
+            # generic linker flag; this should never happen with the
+            # nvidia-cuda-nvcc-cu* package.
+            cuda_info["libs"].append("cudart")
+
+    print("\n\nCUDA INFO:", cuda_info, "\n\n")
     return cuda_info
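The versioned-libcudart selection above can be exercised against a fake directory tree (the directory layout and the file name `libcudart.so.12` below are made up for the demo):

```python
import fnmatch
import os
import pathlib
import tempfile

def find_files(bases, pattern):
    """Recursively collect files matching pattern (mirrors the setup.py helper)."""
    if isinstance(bases, (str, pathlib.Path)):
        bases = [bases]
    result = []
    for base in bases:
        for root, _dirs, files in os.walk(base):
            for name in files:
                if fnmatch.fnmatch(name, pattern):
                    result.append(os.path.join(root, name))
    return result

# Simulate a pip-installed CUDA runtime that ships only a versioned
# libcudart, with no generic libcudart.so symlink.
with tempfile.TemporaryDirectory() as tmp:
    libdir = pathlib.Path(tmp) / "nvidia" / "cuda_runtime" / "lib"
    libdir.mkdir(parents=True)
    (libdir / "libcudart.so.12").touch()

    matches = find_files(tmp, "libcudart.so*")
    lib_to_link = matches[0]
    # -l:libcudart.so.12 makes the linker use that exact file name;
    # the rpath lets the loader find it without LD_LIBRARY_PATH.
    link_flag = f"-l:{os.path.basename(lib_to_link)}"
    rpath = os.path.dirname(lib_to_link)
    print(link_flag)  # -l:libcudart.so.12
```

The `-l:name` form is a GNU ld extension that bypasses the usual `lib<name>.so` expansion, which is exactly what is needed when only a versioned shared object exists.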

@@ -237,13 +329,18 @@ def get_extensions():
         )

     if cuda_info["compile"] and cuda_info["libdirs"]:
+        extra_extension_args = {}
+        if len(cuda_info["rpaths"]) > 0:
+            extra_extension_args["runtime_library_dirs"] = cuda_info["rpaths"]
+
         extensions.append(
             Extension(
                 name=f"{CYTHON_SUBMODULE_NAME}.mpi_xla_bridge_gpu",
                 sources=[f"{CYTHON_SUBMODULE_PATH}/mpi_xla_bridge_gpu.pyx"],
                 include_dirs=cuda_info["compile"],
                 library_dirs=cuda_info["libdirs"],
                 libraries=cuda_info["libs"],
+                **extra_extension_args,
             )
         )
     else:
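How the collected rpaths turn into an `Extension` keyword can be sketched with a small helper (hypothetical, mirroring the logic in the hunk above):

```python
def extension_kwargs(cuda_info):
    """runtime_library_dirs is only passed when rpaths were collected,
    so behaviour is unchanged for classic CUDA installations."""
    extra = {}
    if len(cuda_info.get("rpaths", [])) > 0:
        extra["runtime_library_dirs"] = cuda_info["rpaths"]
    return extra

# Classic install: no rpaths, no extra kwargs.
print(extension_kwargs({"rpaths": []}))  # {}
# pip-nvidia install: the rpath is recorded (path here is illustrative).
print(extension_kwargs({"rpaths": ["/x/nvidia/cuda_runtime/lib"]}))
```

Passing `runtime_library_dirs` makes setuptools emit `-Wl,-rpath,<dir>` at link time, which is how the built `.so` finds the pip-installed `libcudart` at import time.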
