From 08b530aba9bb54eee2621efdacfdfab472546a20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dion=20H=C3=A4fner?= Date: Mon, 4 Nov 2024 13:37:44 +0100 Subject: [PATCH] xpu is back on the menu boys (#266) --- mpi4jax/_src/collective_ops/allgather.py | 10 +++++----- mpi4jax/_src/collective_ops/allreduce.py | 10 +++++----- mpi4jax/_src/collective_ops/alltoall.py | 10 +++++----- mpi4jax/_src/collective_ops/barrier.py | 9 ++++----- mpi4jax/_src/collective_ops/bcast.py | 10 +++++----- mpi4jax/_src/collective_ops/gather.py | 10 +++++----- mpi4jax/_src/collective_ops/recv.py | 10 +++++----- mpi4jax/_src/collective_ops/reduce.py | 10 +++++----- mpi4jax/_src/collective_ops/scan.py | 10 +++++----- mpi4jax/_src/collective_ops/scatter.py | 10 +++++----- mpi4jax/_src/collective_ops/send.py | 10 +++++----- mpi4jax/_src/collective_ops/sendrecv.py | 10 +++++----- mpi4jax/_src/jax_compat.py | 11 ++++++++++- .../experimental/notoken/collective_ops/allgather.py | 9 +++++---- .../experimental/notoken/collective_ops/allreduce.py | 9 +++++---- .../experimental/notoken/collective_ops/alltoall.py | 9 +++++---- .../experimental/notoken/collective_ops/barrier.py | 9 +++++---- mpi4jax/experimental/notoken/collective_ops/bcast.py | 9 +++++---- mpi4jax/experimental/notoken/collective_ops/gather.py | 9 +++++---- mpi4jax/experimental/notoken/collective_ops/recv.py | 9 +++++---- mpi4jax/experimental/notoken/collective_ops/reduce.py | 9 +++++---- mpi4jax/experimental/notoken/collective_ops/scan.py | 9 +++++---- .../experimental/notoken/collective_ops/scatter.py | 9 +++++---- mpi4jax/experimental/notoken/collective_ops/send.py | 9 +++++---- .../experimental/notoken/collective_ops/sendrecv.py | 9 +++++---- 25 files changed, 129 insertions(+), 109 deletions(-) diff --git a/mpi4jax/_src/collective_ops/allgather.py b/mpi4jax/_src/collective_ops/allgather.py index cb1f76a2..91b311b1 100644 --- a/mpi4jax/_src/collective_ops/allgather.py +++ b/mpi4jax/_src/collective_ops/allgather.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -199,6 +199,6 @@ def mpi_allgather_abstract_eval(x, token, comm): mpi_allgather_p.def_impl(mpi_allgather_impl) mpi_allgather_p.def_effectful_abstract_eval(mpi_allgather_abstract_eval) -mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu") +register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu") +register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cuda, platform="cuda") +register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/allreduce.py b/mpi4jax/_src/collective_ops/allreduce.py index 3de3436d..d521454a 100644 --- a/mpi4jax/_src/collective_ops/allreduce.py +++ b/mpi4jax/_src/collective_ops/allreduce.py @@ -6,7 +6,7 @@ from jax.interpreters import ad, batching from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -21,7 +21,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -232,6 +232,6 @@ def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm, transpose): ad.primitive_transposes[mpi_allreduce_p] = mpi_allreduce_transpose_rule # assign to the primitive the correct encoder -mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu") +register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu") +register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cuda, platform="cuda") +register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/alltoall.py b/mpi4jax/_src/collective_ops/alltoall.py index 7f6f6b08..a05f7cce 100644 --- a/mpi4jax/_src/collective_ops/alltoall.py +++ b/mpi4jax/_src/collective_ops/alltoall.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -198,6 +198,6 @@ def mpi_alltoall_abstract_eval(xs, token, comm): mpi_alltoall_p.def_effectful_abstract_eval(mpi_alltoall_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu") +register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu") +register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cuda, platform="cuda") +register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/barrier.py b/mpi4jax/_src/collective_ops/barrier.py index fa8ae069..63fb774b 100644 --- a/mpi4jax/_src/collective_ops/barrier.py +++ b/mpi4jax/_src/collective_ops/barrier.py @@ -6,7 +6,6 @@ from jax.interpreters import batching from jax.lax import create_token -from jax.interpreters import mlir from ..utils import ( HashableMPIType, @@ -19,7 +18,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type +from ..jax_compat import custom_call, register_lowering, token_type from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -135,6 +134,6 @@ def mpi_barrier_batch_eval(in_args, batch_axes, comm): batching.primitive_batchers[mpi_barrier_p] = mpi_barrier_batch_eval # assign to the primitive the correct encoder -mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu") +register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu") +register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cuda, platform="cuda") +register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/bcast.py b/mpi4jax/_src/collective_ops/bcast.py index 5264f212..2a1ef43d 100644 --- a/mpi4jax/_src/collective_ops/bcast.py +++ b/mpi4jax/_src/collective_ops/bcast.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -203,6 +203,6 @@ def mpi_bcast_abstract_eval(xs, token, root, comm): mpi_bcast_p.def_effectful_abstract_eval(mpi_bcast_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu") +register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu") +register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cuda, platform="cuda") +register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/gather.py b/mpi4jax/_src/collective_ops/gather.py index 466e9053..261d29a2 100644 --- a/mpi4jax/_src/collective_ops/gather.py +++ b/mpi4jax/_src/collective_ops/gather.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -237,6 +237,6 @@ def mpi_gather_abstract_eval(x, token, root, comm): mpi_gather_p.def_effectful_abstract_eval(mpi_gather_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_xpu, platform="xpu") +register_lowering(mpi_gather_p, mpi_gather_xla_encode_cpu, platform="cpu") +register_lowering(mpi_gather_p, mpi_gather_xla_encode_cuda, platform="cuda") +register_lowering(mpi_gather_p, mpi_gather_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/recv.py b/mpi4jax/_src/collective_ops/recv.py index ea5e1436..5b21490b 100644 --- a/mpi4jax/_src/collective_ops/recv.py +++ b/mpi4jax/_src/collective_ops/recv.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -21,7 +21,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -213,6 +213,6 @@ def mpi_recv_abstract_eval(xs, token, source, tag, comm, status): mpi_recv_p.def_effectful_abstract_eval(mpi_recv_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_xpu, platform="xpu") +register_lowering(mpi_recv_p, mpi_recv_xla_encode_cpu, platform="cpu") +register_lowering(mpi_recv_p, mpi_recv_xla_encode_cuda, platform="cuda") +register_lowering(mpi_recv_p, mpi_recv_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/reduce.py b/mpi4jax/_src/collective_ops/reduce.py index 54c7c9fe..667b7172 100644 --- a/mpi4jax/_src/collective_ops/reduce.py +++ b/mpi4jax/_src/collective_ops/reduce.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -210,6 +210,6 @@ def mpi_reduce_abstract_eval(xs, token, op, root, comm): mpi_reduce_p.def_effectful_abstract_eval(mpi_reduce_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_xpu, platform="xpu") +register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cpu, platform="cpu") +register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cuda, platform="cuda") +register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/scan.py b/mpi4jax/_src/collective_ops/scan.py index d8a6e49b..4d146fda 100644 --- a/mpi4jax/_src/collective_ops/scan.py +++ b/mpi4jax/_src/collective_ops/scan.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -178,6 +178,6 @@ def mpi_scan_abstract_eval(xs, token, op, comm): mpi_scan_p.def_effectful_abstract_eval(mpi_scan_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_xpu, platform="xpu") +register_lowering(mpi_scan_p, mpi_scan_xla_encode_cpu, platform="cpu") +register_lowering(mpi_scan_p, mpi_scan_xla_encode_cuda, platform="cuda") +register_lowering(mpi_scan_p, mpi_scan_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/scatter.py b/mpi4jax/_src/collective_ops/scatter.py index a665c108..934bfe19 100644 --- a/mpi4jax/_src/collective_ops/scatter.py +++ b/mpi4jax/_src/collective_ops/scatter.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -227,6 +227,6 @@ def mpi_scatter_abstract_eval(x, token, root, comm): mpi_scatter_p.def_effectful_abstract_eval(mpi_scatter_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_xpu, platform="xpu") +register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cpu, platform="cpu") +register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cuda, platform="cuda") +register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/send.py b/mpi4jax/_src/collective_ops/send.py index 1120fb07..8f36e44c 100644 --- a/mpi4jax/_src/collective_ops/send.py +++ b/mpi4jax/_src/collective_ops/send.py @@ -5,7 +5,7 @@ from jax.core import Primitive, Tracer, Token from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -20,7 +20,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type +from ..jax_compat import custom_call, register_lowering, token_type from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -164,6 +164,6 @@ def mpi_send_abstract_eval(xs, token, dest, tag, comm): mpi_send_p.def_effectful_abstract_eval(mpi_send_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_send_p, mpi_send_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_send_p, mpi_send_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_send_p, mpi_send_xla_encode_xpu, platform="xpu") +register_lowering(mpi_send_p, mpi_send_xla_encode_cpu, platform="cpu") +register_lowering(mpi_send_p, mpi_send_xla_encode_cuda, platform="cuda") +register_lowering(mpi_send_p, mpi_send_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/collective_ops/sendrecv.py b/mpi4jax/_src/collective_ops/sendrecv.py index 542c7eda..803f983e 100644 --- a/mpi4jax/_src/collective_ops/sendrecv.py +++ b/mpi4jax/_src/collective_ops/sendrecv.py @@ -6,7 +6,7 @@ from jax.interpreters import ad, batching from jax.lax import create_token -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from ..utils import ( @@ -22,7 +22,7 @@ effect, prefer_notoken, ) -from ..jax_compat import custom_call, token_type, ShapedArray +from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray from ..decorators import ( translation_rule_cpu, translation_rule_cuda, @@ -425,6 +425,6 @@ def mpi_sendrecv_transpose_rule( ad.primitive_transposes[mpi_sendrecv_p] = mpi_sendrecv_transpose_rule # assign to the primitive the correct encoder -mlir.register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_xpu, platform="xpu") +register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cpu, platform="cpu") +register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cuda, platform="cuda") +register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/_src/jax_compat.py b/mpi4jax/_src/jax_compat.py index 1775262b..24df5305 100644 --- a/mpi4jax/_src/jax_compat.py +++ b/mpi4jax/_src/jax_compat.py @@ -5,6 +5,7 @@ import jax import jaxlib +from jax.interpreters import mlir from jax.interpreters.mlir import token_type as jax_token_type, TokenSet @@ -47,6 +48,15 @@ def check_jax_version(): ) +def register_lowering(prim, rule, platform="cpu"): + try: + return mlir.register_lowering(prim, rule, platform=platform) + except NotImplementedError: + # Raised if the platform is supplied by a non-installed plugin + assert platform != "cpu" + return None + + # TODO: remove the other path once we require jax >= 0.4.31 if versiontuple(jax.__version__) >= (0, 4, 31): token_type = jax_token_type @@ -118,7 +128,6 @@ def register_effect(EffectType, ordered=False): EffectType = object def register_effect(EffectType, ordered=False): - from jax.interpreters import mlir from jax._src.lax import control_flow as lcf import jax._src.custom_derivatives as custom_derivatives diff --git a/mpi4jax/experimental/notoken/collective_ops/allgather.py b/mpi4jax/experimental/notoken/collective_ops/allgather.py index ad199355..05af3810 100644 --- a/mpi4jax/experimental/notoken/collective_ops/allgather.py +++ b/mpi4jax/experimental/notoken/collective_ops/allgather.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -197,6 +198,6 @@ def mpi_allgather_abstract_eval(x, comm): mpi_allgather_p.def_impl(mpi_allgather_impl) mpi_allgather_p.def_effectful_abstract_eval(mpi_allgather_abstract_eval) -mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu") +register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu") +register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cuda, platform="cuda") +register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/allreduce.py b/mpi4jax/experimental/notoken/collective_ops/allreduce.py index 94fcbd0b..81a9078e 100644 --- a/mpi4jax/experimental/notoken/collective_ops/allreduce.py +++ b/mpi4jax/experimental/notoken/collective_ops/allreduce.py @@ -4,7 +4,7 @@ from jax.core import Primitive from jax.interpreters import ad, batching -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -19,6 +19,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -231,6 +232,6 @@ def mpi_allreduce_transpose_rule(x_tan, *x_args, op, comm, transpose): ad.primitive_transposes[mpi_allreduce_p] = mpi_allreduce_transpose_rule # assign to the primitive the correct encoder -mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu") +register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu") +register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cuda, platform="cuda") +register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/alltoall.py b/mpi4jax/experimental/notoken/collective_ops/alltoall.py index 086a8aaf..4a3ecf89 100644 --- a/mpi4jax/experimental/notoken/collective_ops/alltoall.py +++ b/mpi4jax/experimental/notoken/collective_ops/alltoall.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -197,6 +198,6 @@ def mpi_alltoall_abstract_eval(xs, comm): mpi_alltoall_p.def_effectful_abstract_eval(mpi_alltoall_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu") +register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu") +register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cuda, platform="cuda") +register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/barrier.py b/mpi4jax/experimental/notoken/collective_ops/barrier.py index 5f7a1314..eb0b32b3 100644 --- a/mpi4jax/experimental/notoken/collective_ops/barrier.py +++ b/mpi4jax/experimental/notoken/collective_ops/barrier.py @@ -2,7 +2,7 @@ from mpi4py import MPI as _MPI from jax.core import Primitive -from jax.interpreters import batching, mlir +from jax.interpreters import batching from mpi4jax._src.utils import ( HashableMPIType, @@ -15,6 +15,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, get_token_effect, @@ -136,6 +137,6 @@ def mpi_barrier_batch_eval(in_args, batch_axes, comm): batching.primitive_batchers[mpi_barrier_p] = mpi_barrier_batch_eval # assign to the primitive the correct encoder -mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu") +register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu") +register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cuda, platform="cuda") +register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/bcast.py b/mpi4jax/experimental/notoken/collective_ops/bcast.py index 17eda5ed..7115f79e 100644 --- a/mpi4jax/experimental/notoken/collective_ops/bcast.py +++ b/mpi4jax/experimental/notoken/collective_ops/bcast.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -206,6 +207,6 @@ def mpi_bcast_abstract_eval(xs, root, comm): mpi_bcast_p.def_effectful_abstract_eval(mpi_bcast_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu") +register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu") +register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cuda, platform="cuda") +register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/gather.py b/mpi4jax/experimental/notoken/collective_ops/gather.py index 0028516f..155ee472 100644 --- a/mpi4jax/experimental/notoken/collective_ops/gather.py +++ b/mpi4jax/experimental/notoken/collective_ops/gather.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -237,6 +238,6 @@ def mpi_gather_abstract_eval(x, root, comm): mpi_gather_p.def_effectful_abstract_eval(mpi_gather_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_xpu, platform="xpu") +register_lowering(mpi_gather_p, mpi_gather_xla_encode_cpu, platform="cpu") +register_lowering(mpi_gather_p, mpi_gather_xla_encode_cuda, platform="cuda") +register_lowering(mpi_gather_p, mpi_gather_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/recv.py b/mpi4jax/experimental/notoken/collective_ops/recv.py index a71bb79f..fb3006d0 100644 --- a/mpi4jax/experimental/notoken/collective_ops/recv.py +++ b/mpi4jax/experimental/notoken/collective_ops/recv.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -19,6 +19,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -212,6 +213,6 @@ def mpi_recv_abstract_eval(xs, source, tag, comm, status): mpi_recv_p.def_effectful_abstract_eval(mpi_recv_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_xpu, platform="xpu") +register_lowering(mpi_recv_p, mpi_recv_xla_encode_cpu, platform="cpu") +register_lowering(mpi_recv_p, mpi_recv_xla_encode_cuda, platform="cuda") +register_lowering(mpi_recv_p, mpi_recv_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/reduce.py b/mpi4jax/experimental/notoken/collective_ops/reduce.py index 1a80b06b..ebd062ab 100644 --- a/mpi4jax/experimental/notoken/collective_ops/reduce.py +++ b/mpi4jax/experimental/notoken/collective_ops/reduce.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -212,6 +213,6 @@ def mpi_reduce_abstract_eval(xs, op, root, comm): mpi_reduce_p.def_effectful_abstract_eval(mpi_reduce_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_xpu, platform="xpu") +register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cpu, platform="cpu") +register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cuda, platform="cuda") +register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/scan.py b/mpi4jax/experimental/notoken/collective_ops/scan.py index b69dcac2..69c884da 100644 --- a/mpi4jax/experimental/notoken/collective_ops/scan.py +++ b/mpi4jax/experimental/notoken/collective_ops/scan.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -180,6 +181,6 @@ def mpi_scan_abstract_eval(xs, op, comm): mpi_scan_p.def_effectful_abstract_eval(mpi_scan_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_xpu, platform="xpu") +register_lowering(mpi_scan_p, mpi_scan_xla_encode_cpu, platform="cpu") +register_lowering(mpi_scan_p, mpi_scan_xla_encode_cuda, platform="cuda") +register_lowering(mpi_scan_p, mpi_scan_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/scatter.py b/mpi4jax/experimental/notoken/collective_ops/scatter.py index fe8285cb..8c8ffc1d 100644 --- a/mpi4jax/experimental/notoken/collective_ops/scatter.py +++ b/mpi4jax/experimental/notoken/collective_ops/scatter.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -227,6 +228,6 @@ def mpi_scatter_abstract_eval(x, root, comm): mpi_scatter_p.def_effectful_abstract_eval(mpi_scatter_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_xpu, platform="xpu") +register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cpu, platform="cpu") +register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cuda, platform="cuda") +register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/send.py b/mpi4jax/experimental/notoken/collective_ops/send.py index ec28a3ad..8b1d1cfd 100644 --- a/mpi4jax/experimental/notoken/collective_ops/send.py +++ b/mpi4jax/experimental/notoken/collective_ops/send.py @@ -3,7 +3,7 @@ from jax.core import Primitive -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -18,6 +18,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, get_token_effect, @@ -174,6 +175,6 @@ def mpi_send_abstract_eval(xs, dest, tag, comm): mpi_send_p.def_effectful_abstract_eval(mpi_send_abstract_eval) # assign to the primitive the correct encoder -mlir.register_lowering(mpi_send_p, mpi_send_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_send_p, mpi_send_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_send_p, mpi_send_xla_encode_xpu, platform="xpu") +register_lowering(mpi_send_p, mpi_send_xla_encode_cpu, platform="cpu") +register_lowering(mpi_send_p, mpi_send_xla_encode_cuda, platform="cuda") +register_lowering(mpi_send_p, mpi_send_xla_encode_xpu, platform="xpu") diff --git a/mpi4jax/experimental/notoken/collective_ops/sendrecv.py b/mpi4jax/experimental/notoken/collective_ops/sendrecv.py index 3c23b376..790344a6 100644 --- a/mpi4jax/experimental/notoken/collective_ops/sendrecv.py +++ b/mpi4jax/experimental/notoken/collective_ops/sendrecv.py @@ -4,7 +4,7 @@ from jax.core import Primitive, get_aval from jax.interpreters import ad, batching -from jax.interpreters import mlir + import jaxlib.mlir.ir as ir from mpi4jax._src.utils import ( @@ -20,6 +20,7 @@ ordered_effect, ) from mpi4jax._src.jax_compat import ( + register_lowering, custom_call, token_type, ShapedArray, @@ -453,6 +454,6 @@ def mpi_sendrecv_transpose_rule( ad.primitive_transposes[mpi_sendrecv_p] = mpi_sendrecv_transpose_rule # assign to the primitive the correct encoder -mlir.register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cpu, platform="cpu") -mlir.register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cuda, platform="cuda") -# mlir.register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_xpu, platform="xpu") +register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cpu, platform="cpu") +register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_cuda, platform="cuda") +register_lowering(mpi_sendrecv_p, mpi_sendrecv_xla_encode_xpu, platform="xpu")