Commit ffd431b

Extending mpi4jax with XPU support (#226)
1 parent 72bb133 commit ffd431b

40 files changed: +1661 / -360 lines

.github/workflows/build-xpu-ext.yml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+name: Build XPU extensions
+
+on:
+  pull_request:
+
+  push:
+    branches:
+      - master
+
+jobs:
+  build:
+    runs-on: ubuntu-22.04
+
+    strategy:
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+
+      # make sure tags are fetched so we can get a version
+      - run: |
+          git fetch --prune --unshallow --tags
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.x'
+
+      - name: Install OneAPI components
+        run: |
+          wget -nv https://registrationcenter-download.intel.com/akdlm/IRC_NAS/bb99984f-370f-413d-bbec-38928d2458f2/l_dpcpp-cpp-compiler_p_2024.0.2.29_offline.sh -P $HOME/basekit
+          chmod +x $HOME/basekit/l_dpcpp-cpp-compiler_p_2024.0.2.29_offline.sh
+          bash $HOME/basekit/l_dpcpp-cpp-compiler_p_2024.0.2.29_offline.sh -f "$HOME/basekit" -a --install-dir "$HOME/basekit" --eula=accept --silent
+        shell: bash
+
+      - name: Setup MPI (mpich)
+        uses: mpi4py/setup-mpi@v1
+        with:
+          mpi: mpich
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install setuptools wheel mpi4py cython
+
+      - name: Build XPU extensions
+        run: |
+          source $HOME/basekit/setvars.sh
+          python setup.py build_ext --inplace
+          test -f mpi4jax/_src/xla_bridge/mpi_xla_bridge_xpu*.so

docs/installation.rst

Lines changed: 18 additions & 3 deletions
@@ -62,14 +62,14 @@ usually sufficient to specify the ``MPICC`` environment variable *before* instal
 In doubt, please refer to `the mpi4py documentation <https://mpi4py.readthedocs.io/en/stable/install.html>`_.
 
 
-Installation with GPU support
------------------------------
+Installation with NVIDIA GPU support (CUDA)
+-------------------------------------------
 
 .. note::
 
    To use JAX on the GPU, make sure that your ``jaxlib`` is `built with CUDA support <https://github.com/google/jax#installation>`_.
 
-``mpi4jax`` also supports JAX arrays stored in GPU memory.
+``mpi4jax`` supports communication of JAX arrays stored in GPU memory.
 
 To build ``mpi4jax``'s GPU extensions, we need to be able to locate the CUDA headers on your system. If they are not detected automatically, you can set the environment variable :envvar:`CUDA_ROOT` when installing ``mpi4jax``::
 
@@ -86,3 +86,18 @@ If this is a bottleneck in your application, you can build MPI with CUDA support
 .. seealso::
 
    Read :ref:`here <gpu-usage>` on how to use zero-copy GPU communication after installation.
+
+
+Installation with Intel GPU/XPU support
+---------------------------------------
+
+``mpi4jax`` supports communication of JAX arrays stored in Intel GPU/XPU memory, via JAX's ``xpu`` backend.
+
+**Requirements:**
+
+- `Intel extension for OpenXLA <https://github.com/intel/intel-extension-for-openxla>`__, version 0.3.0 or newer.
+- SYCL headers and libraries, which come as part of the `Intel oneAPI Base Toolkit <https://www.intel.com/content/www/us/en/developer/tools/oneapi/ai-analytics-toolkit.html>`__.
+- Optionally, `Intel MPI <https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/mpi-library.html>`__ with Intel XPU/GPU support.
+  To leverage this, you also need to rebuild `mpi4py <https://mpi4py.readthedocs.io/en/stable/install.html>`__ to ensure it is linked against the XPU/GPU-aware MPI implementation.
+
+An example setup can be found in the `mpi4jax test suite <https://github.com/mpi4jax/mpi4jax/tree/master/.github/workflows/build-xpu-ext.yml>`__.
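
Once the extension builds, a quick way to confirm that the XPU path is available is to query the ``has_sycl_support()`` helper exported by this commit and ask JAX for its ``xpu`` devices. A minimal sketch; the ``jax.devices("xpu")`` call assumes the Intel extension for OpenXLA has registered the ``xpu`` platform on your machine:

# check_xpu.py -- sanity check after building mpi4jax with XPU support
import jax
import mpi4jax

# True only if the SYCL/XPU bridge extension was built and can be imported
print("mpi4jax SYCL support:", mpi4jax.has_sycl_support())

# Lists Intel GPU/XPU devices; raises RuntimeError if no "xpu" platform is
# registered with JAX (i.e. the Intel extension for OpenXLA is missing)
print("JAX xpu devices:", jax.devices("xpu"))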

docs/sharp-bits.rst

Lines changed: 23 additions & 0 deletions
@@ -78,6 +78,29 @@ Data will then be copied directly from GPU to GPU. If your MPI library
 does not have CUDA support, you will receive a segmentation fault when
 trying to access GPU memory.
 
+Using Intel XPU-aware MPI
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``mpi4jax`` is able to communicate data directly from and to Intel XPU
+and Intel GPU memory. This requires an MPI installation that is
+Intel GPU/XPU aware (i.e. MPI calls can work directly with XPU/GPU memory)
+and that JAX and `mpi4jax are built with Intel XPU
+support <installation>`__.
+
+Currently, we cannot detect whether MPI is XPU/GPU aware. Therefore, by
+default, ``mpi4jax`` will not read directly from XPU/GPU memory, but
+instead copy to the CPU and back.
+
+If you are certain that the underlying MPI library is XPU/GPU aware,
+you can set the following environment variable:
+
+.. code:: bash
+
+   $ export MPI4JAX_USE_SYCL_MPI=1
+
+Data will then be copied directly from XPU to XPU. If your MPI library
+cannot work with Intel GPU/XPU buffers, you will receive a segmentation
+fault when trying to access the GPU/XPU memory.
 
 Using ``mpi4jax`` *and* ``mpi4py``
 ----------------------------------
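
To make the new XPU-aware MPI passage above concrete, here is a rough sketch of a zero-copy run. ``mpi4jax.allreduce`` and the ``MPI4JAX_USE_SYCL_MPI`` switch come from this commit and the existing API; placing the array on an ``xpu`` device via ``jax.device_put`` is an assumption about how the Intel extension for OpenXLA exposes devices, and whether ``allreduce`` also returns a token depends on your mpi4jax version:

# allreduce_xpu.py -- run with:
#   MPI4JAX_USE_SYCL_MPI=1 mpirun -n 2 python allreduce_xpu.py
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

# place the input in Intel GPU/XPU memory (assumes an "xpu" platform is registered)
x = jax.device_put(jnp.ones(8), jax.devices("xpu")[0])

# with MPI4JAX_USE_SYCL_MPI=1 and an XPU-aware MPI, the transfer happens
# directly between device buffers; otherwise mpi4jax stages through the CPU
result = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)
print(comm.Get_rank(), result)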

mpi4jax/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
     send,
     sendrecv,
     has_cuda_support,
+    has_sycl_support,
 )
 
 __all__ = [
@@ -36,4 +37,5 @@
     "send",
     "sendrecv",
     "has_cuda_support",
+    "has_sycl_support",
 ]

mpi4jax/_src/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@
 from .collective_ops.send import send  # noqa: F401, E402
 from .collective_ops.sendrecv import sendrecv  # noqa: F401, E402
 
-from .utils import has_cuda_support  # noqa: F401, E402
+from .utils import has_cuda_support, has_sycl_support  # noqa: F401, E402
 
 # sanitize namespace
 del jax_compat, xla_bridge, MPI, atexit, flush

mpi4jax/_src/collective_ops/allgather.py

Lines changed: 13 additions & 5 deletions
@@ -21,10 +21,16 @@
     prefer_notoken,
 )
 from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import (
+    translation_rule_cpu,
+    translation_rule_gpu,
+    translation_rule_xpu,
+)
 from ..validation import enforce_types
 from ..comm import get_default_comm
 
+from ..xla_bridge.device_descriptors import build_allgather_descriptor
+
 # The Jax primitive
 mpi_allgather_p = Primitive("allgather_mpi")  # Create the primitive
 mpi_allgather_impl = default_primitive_impl(mpi_allgather_p)
@@ -128,10 +134,7 @@ def mpi_allgather_xla_encode_cpu(ctx, sendbuf, token, comm):
     ).results
 
 
-@translation_rule_gpu
-def mpi_allgather_xla_encode_gpu(ctx, sendbuf, token, comm):
-    from ..xla_bridge.mpi_xla_bridge_gpu import build_allgather_descriptor
-
+def mpi_allgather_xla_encode_device(ctx, sendbuf, token, comm):
     comm = unpack_hashable(comm)
 
     sendbuf_aval, *_ = ctx.avals_in
@@ -177,6 +180,10 @@ def mpi_allgather_xla_encode_gpu(ctx, sendbuf, token, comm):
     ).results
 
 
+mpi_allgather_xla_encode_xpu = translation_rule_xpu(mpi_allgather_xla_encode_device)
+mpi_allgather_xla_encode_gpu = translation_rule_gpu(mpi_allgather_xla_encode_device)
+
+
 # This function evaluates only the shapes during AST construction
 def mpi_allgather_abstract_eval(x, token, comm):
     comm = unpack_hashable(comm)
@@ -194,3 +201,4 @@ def mpi_allgather_abstract_eval(x, token, comm):
 
 mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu")
 mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_gpu, platform="cuda")
+mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu")
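
The registration for ``platform="xpu"`` above is what lets this primitive be traced and compiled by XLA on the Intel backend, just as on CPU and CUDA; the same refactoring (a shared ``*_xla_encode_device`` lowering wrapped by ``translation_rule_gpu`` / ``translation_rule_xpu``) repeats in the collectives below. A rough usage sketch, assuming an ``xpu``-enabled JAX install; whether ``allgather`` also returns a token depends on the mpi4jax version:

# allgather_xpu.py -- run with: mpirun -n 2 python allgather_xpu.py
from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax

comm = MPI.COMM_WORLD

@jax.jit
def gather_ranks(x):
    # when compiled for the Intel backend, this lowers to the custom call
    # registered above for platform="xpu"
    return mpi4jax.allgather(x, comm=comm)

x = jax.device_put(jnp.array([comm.Get_rank()]), jax.devices("xpu")[0])
print(gather_ranks(x))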

mpi4jax/_src/collective_ops/allreduce.py

Lines changed: 12 additions & 5 deletions
@@ -22,10 +22,15 @@
     prefer_notoken,
 )
 from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import (
+    translation_rule_cpu,
+    translation_rule_gpu,
+    translation_rule_xpu,
+)
 from ..validation import enforce_types
 from ..comm import get_default_comm
 
+from ..xla_bridge.device_descriptors import build_allreduce_descriptor
 
 # The Jax primitive
 mpi_allreduce_p = Primitive("allreduce_mpi")  # Create the primitive
@@ -122,10 +127,7 @@ def mpi_allreduce_xla_encode_cpu(ctx, x, token, op, comm, transpose):
     ).results
 
 
-@translation_rule_gpu
-def mpi_allreduce_xla_encode_gpu(ctx, x, token, op, comm, transpose):
-    from ..xla_bridge.mpi_xla_bridge_gpu import build_allreduce_descriptor
-
+def mpi_allreduce_xla_encode_device(ctx, x, token, op, comm, transpose):
    op = unpack_hashable(op)
    comm = unpack_hashable(comm)

@@ -171,6 +173,10 @@ def mpi_allreduce_xla_encode_gpu(ctx, x, token, op, comm, transpose):
     ).results
 
 
+mpi_allreduce_xla_encode_gpu = translation_rule_gpu(mpi_allreduce_xla_encode_device)
+mpi_allreduce_xla_encode_xpu = translation_rule_xpu(mpi_allreduce_xla_encode_device)
+
+
 # This function evaluates only the shapes during AST construction
 def mpi_allreduce_abstract_eval(xs, token, op, comm, transpose):
     return (
@@ -230,3 +236,4 @@ def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm, transpose):
 # assign to the primitive the correct encoder
 mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu")
 mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_gpu, platform="cuda")
+mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/alltoall.py

Lines changed: 12 additions & 6 deletions
@@ -21,10 +21,14 @@
     prefer_notoken,
 )
 from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import (
+    translation_rule_cpu,
+    translation_rule_gpu,
+    translation_rule_xpu,
+)
 from ..validation import enforce_types
 from ..comm import get_default_comm
-
+from ..xla_bridge.device_descriptors import build_alltoall_descriptor
 
 # The Jax primitive
 mpi_alltoall_p = Primitive("alltoall_mpi")  # Create the primitive
@@ -129,10 +133,7 @@ def mpi_alltoall_xla_encode_cpu(ctx, x, token, comm):
     ).results
 
 
-@translation_rule_gpu
-def mpi_alltoall_xla_encode_gpu(ctx, x, token, comm):
-    from ..xla_bridge.mpi_xla_bridge_gpu import build_alltoall_descriptor
-
+def mpi_alltoall_xla_encode_device(ctx, x, token, comm):
     comm = unpack_hashable(comm)
 
     x_aval, *_ = ctx.avals_in
@@ -180,6 +181,10 @@
     ).results
 
 
+mpi_alltoall_xla_encode_xpu = translation_rule_xpu(mpi_alltoall_xla_encode_device)
+mpi_alltoall_xla_encode_gpu = translation_rule_gpu(mpi_alltoall_xla_encode_device)
+
+
 # This function evaluates only the shapes during AST construction
 def mpi_alltoall_abstract_eval(xs, token, comm):
     return (
@@ -195,3 +200,4 @@ def mpi_alltoall_abstract_eval(xs, token, comm):
 # assign to the primitive the correct encoder
 mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu")
 mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_gpu, platform="cuda")
+mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/barrier.py

Lines changed: 12 additions & 5 deletions
@@ -20,9 +20,14 @@
     prefer_notoken,
 )
 from ..jax_compat import custom_call, token_type
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import (
+    translation_rule_cpu,
+    translation_rule_gpu,
+    translation_rule_xpu,
+)
 from ..validation import enforce_types
 from ..comm import get_default_comm
+from ..xla_bridge.device_descriptors import build_barrier_descriptor
 
 
 # The Jax primitive
@@ -89,10 +94,7 @@ def mpi_barrier_xla_encode_cpu(ctx, token, comm):
     ).results
 
 
-@translation_rule_gpu
-def mpi_barrier_xla_encode_gpu(ctx, token, comm):
-    from ..xla_bridge.mpi_xla_bridge_gpu import build_barrier_descriptor
-
+def mpi_barrier_xla_encode_device(ctx, token, comm):
     comm = unpack_hashable(comm)
 
     out_types = token_type()
@@ -112,6 +114,10 @@ def mpi_barrier_xla_encode_gpu(ctx, token, comm):
     ).results
 
 
+mpi_barrier_xla_encode_xpu = translation_rule_xpu(mpi_barrier_xla_encode_device)
+mpi_barrier_xla_encode_gpu = translation_rule_gpu(mpi_barrier_xla_encode_device)
+
+
 # This function evaluates only the shapes during AST construction
 def mpi_barrier_abstract_eval(token, comm):
     return core.abstract_token, {effect}
@@ -131,3 +137,4 @@ def mpi_barrier_batch_eval(in_args, batch_axes, comm):
 # assign to the primitive the correct encoder
 mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu")
 mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_gpu, platform="cuda")
+mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/bcast.py

Lines changed: 12 additions & 5 deletions
@@ -21,9 +21,14 @@
     prefer_notoken,
 )
 from ..jax_compat import custom_call, token_type, ShapedArray
-from ..decorators import translation_rule_cpu, translation_rule_gpu
+from ..decorators import (
+    translation_rule_cpu,
+    translation_rule_gpu,
+    translation_rule_xpu,
+)
 from ..validation import enforce_types
 from ..comm import get_default_comm
+from ..xla_bridge.device_descriptors import build_bcast_descriptor
 
 
 # The Jax primitive
@@ -126,10 +131,7 @@ def mpi_bcast_xla_encode_cpu(ctx, x, token, root, comm):
     ).results
 
 
-@translation_rule_gpu
-def mpi_bcast_xla_encode_gpu(ctx, x, token, root, comm):
-    from ..xla_bridge.mpi_xla_bridge_gpu import build_bcast_descriptor
-
+def mpi_bcast_xla_encode_device(ctx, x, token, root, comm):
     comm = unpack_hashable(comm)
 
     x_aval, *_ = ctx.avals_in
@@ -176,6 +178,10 @@ def mpi_bcast_xla_encode_gpu(ctx, x, token, root, comm):
     ).results
 
 
+mpi_bcast_xla_encode_xpu = translation_rule_xpu(mpi_bcast_xla_encode_device)
+mpi_bcast_xla_encode_gpu = translation_rule_gpu(mpi_bcast_xla_encode_device)
+
+
 # This function evaluates only the shapes during AST construction
 def mpi_bcast_abstract_eval(xs, token, root, comm):
     comm = unpack_hashable(comm)
@@ -199,3 +205,4 @@
 # assign to the primitive the correct encoder
 mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu")
 mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_gpu, platform="cuda")
+mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu")
