Skip to content

Commit 08b530a

Browse files
authored
xpu is back on the menu boys (#266)
1 parent 1570a8e commit 08b530a

25 files changed

+129
-109
lines changed

mpi4jax/_src/collective_ops/allgather.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -199,6 +199,6 @@ def mpi_allgather_abstract_eval(x, token, comm):
199199
mpi_allgather_p.def_impl(mpi_allgather_impl)
200200
mpi_allgather_p.def_effectful_abstract_eval(mpi_allgather_abstract_eval)
201201

202-
mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu")
203-
mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cuda, platform="cuda")
204-
# mlir.register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu")
202+
register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cpu, platform="cpu")
203+
register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_cuda, platform="cuda")
204+
register_lowering(mpi_allgather_p, mpi_allgather_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/allreduce.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from jax.interpreters import ad, batching
77
from jax.lax import create_token
88

9-
from jax.interpreters import mlir
9+
1010
import jaxlib.mlir.ir as ir
1111

1212
from ..utils import (
@@ -21,7 +21,7 @@
2121
effect,
2222
prefer_notoken,
2323
)
24-
from ..jax_compat import custom_call, token_type, ShapedArray
24+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2525
from ..decorators import (
2626
translation_rule_cpu,
2727
translation_rule_cuda,
@@ -232,6 +232,6 @@ def mpi_allreduce_transpose_rule(tan_args, *x_args, op, comm, transpose):
232232
ad.primitive_transposes[mpi_allreduce_p] = mpi_allreduce_transpose_rule
233233

234234
# assign to the primitive the correct encoder
235-
mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu")
236-
mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cuda, platform="cuda")
237-
# mlir.register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu")
235+
register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cpu, platform="cpu")
236+
register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_cuda, platform="cuda")
237+
register_lowering(mpi_allreduce_p, mpi_allreduce_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/alltoall.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -198,6 +198,6 @@ def mpi_alltoall_abstract_eval(xs, token, comm):
198198
mpi_alltoall_p.def_effectful_abstract_eval(mpi_alltoall_abstract_eval)
199199

200200
# assign to the primitive the correct encoder
201-
mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu")
202-
mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cuda, platform="cuda")
203-
# mlir.register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu")
201+
register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cpu, platform="cpu")
202+
register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_cuda, platform="cuda")
203+
register_lowering(mpi_alltoall_p, mpi_alltoall_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/barrier.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from jax.interpreters import batching
77
from jax.lax import create_token
88

9-
from jax.interpreters import mlir
109

1110
from ..utils import (
1211
HashableMPIType,
@@ -19,7 +18,7 @@
1918
effect,
2019
prefer_notoken,
2120
)
22-
from ..jax_compat import custom_call, token_type
21+
from ..jax_compat import custom_call, register_lowering, token_type
2322
from ..decorators import (
2423
translation_rule_cpu,
2524
translation_rule_cuda,
@@ -135,6 +134,6 @@ def mpi_barrier_batch_eval(in_args, batch_axes, comm):
135134
batching.primitive_batchers[mpi_barrier_p] = mpi_barrier_batch_eval
136135

137136
# assign to the primitive the correct encoder
138-
mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu")
139-
mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cuda, platform="cuda")
140-
# mlir.register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu")
137+
register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cpu, platform="cpu")
138+
register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_cuda, platform="cuda")
139+
register_lowering(mpi_barrier_p, mpi_barrier_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/bcast.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -203,6 +203,6 @@ def mpi_bcast_abstract_eval(xs, token, root, comm):
203203
mpi_bcast_p.def_effectful_abstract_eval(mpi_bcast_abstract_eval)
204204

205205
# assign to the primitive the correct encoder
206-
mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu")
207-
mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cuda, platform="cuda")
208-
# mlir.register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu")
206+
register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cpu, platform="cpu")
207+
register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_cuda, platform="cuda")
208+
register_lowering(mpi_bcast_p, mpi_bcast_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/gather.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -237,6 +237,6 @@ def mpi_gather_abstract_eval(x, token, root, comm):
237237
mpi_gather_p.def_effectful_abstract_eval(mpi_gather_abstract_eval)
238238

239239
# assign to the primitive the correct encoder
240-
mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_cpu, platform="cpu")
241-
mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_cuda, platform="cuda")
242-
# mlir.register_lowering(mpi_gather_p, mpi_gather_xla_encode_xpu, platform="xpu")
240+
register_lowering(mpi_gather_p, mpi_gather_xla_encode_cpu, platform="cpu")
241+
register_lowering(mpi_gather_p, mpi_gather_xla_encode_cuda, platform="cuda")
242+
register_lowering(mpi_gather_p, mpi_gather_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/recv.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -21,7 +21,7 @@
2121
effect,
2222
prefer_notoken,
2323
)
24-
from ..jax_compat import custom_call, token_type, ShapedArray
24+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2525
from ..decorators import (
2626
translation_rule_cpu,
2727
translation_rule_cuda,
@@ -213,6 +213,6 @@ def mpi_recv_abstract_eval(xs, token, source, tag, comm, status):
213213
mpi_recv_p.def_effectful_abstract_eval(mpi_recv_abstract_eval)
214214

215215
# assign to the primitive the correct encoder
216-
mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_cpu, platform="cpu")
217-
mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_cuda, platform="cuda")
218-
# mlir.register_lowering(mpi_recv_p, mpi_recv_xla_encode_xpu, platform="xpu")
216+
register_lowering(mpi_recv_p, mpi_recv_xla_encode_cpu, platform="cpu")
217+
register_lowering(mpi_recv_p, mpi_recv_xla_encode_cuda, platform="cuda")
218+
register_lowering(mpi_recv_p, mpi_recv_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/reduce.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -210,6 +210,6 @@ def mpi_reduce_abstract_eval(xs, token, op, root, comm):
210210
mpi_reduce_p.def_effectful_abstract_eval(mpi_reduce_abstract_eval)
211211

212212
# assign to the primitive the correct encoder
213-
mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cpu, platform="cpu")
214-
mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cuda, platform="cuda")
215-
# mlir.register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_xpu, platform="xpu")
213+
register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cpu, platform="cpu")
214+
register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_cuda, platform="cuda")
215+
register_lowering(mpi_reduce_p, mpi_reduce_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/scan.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -178,6 +178,6 @@ def mpi_scan_abstract_eval(xs, token, op, comm):
178178
mpi_scan_p.def_effectful_abstract_eval(mpi_scan_abstract_eval)
179179

180180
# assign to the primitive the correct encoder
181-
mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_cpu, platform="cpu")
182-
mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_cuda, platform="cuda")
183-
# mlir.register_lowering(mpi_scan_p, mpi_scan_xla_encode_xpu, platform="xpu")
181+
register_lowering(mpi_scan_p, mpi_scan_xla_encode_cpu, platform="cpu")
182+
register_lowering(mpi_scan_p, mpi_scan_xla_encode_cuda, platform="cuda")
183+
register_lowering(mpi_scan_p, mpi_scan_xla_encode_xpu, platform="xpu")

mpi4jax/_src/collective_ops/scatter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jax.core import Primitive, Tracer, Token
66
from jax.lax import create_token
77

8-
from jax.interpreters import mlir
8+
99
import jaxlib.mlir.ir as ir
1010

1111
from ..utils import (
@@ -20,7 +20,7 @@
2020
effect,
2121
prefer_notoken,
2222
)
23-
from ..jax_compat import custom_call, token_type, ShapedArray
23+
from ..jax_compat import custom_call, register_lowering, token_type, ShapedArray
2424
from ..decorators import (
2525
translation_rule_cpu,
2626
translation_rule_cuda,
@@ -227,6 +227,6 @@ def mpi_scatter_abstract_eval(x, token, root, comm):
227227
mpi_scatter_p.def_effectful_abstract_eval(mpi_scatter_abstract_eval)
228228

229229
# assign to the primitive the correct encoder
230-
mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cpu, platform="cpu")
231-
mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cuda, platform="cuda")
232-
# mlir.register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_xpu, platform="xpu")
230+
register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cpu, platform="cpu")
231+
register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_cuda, platform="cuda")
232+
register_lowering(mpi_scatter_p, mpi_scatter_xla_encode_xpu, platform="xpu")

0 commit comments

Comments
 (0)