Skip to content

Commit 2593655

Browse files
Fixes for cuda pjrt plugin (#241)
* fixes for cuda pjrt plugin * format * always decode * kernel names shouldn't be bytes to begin with * more byte names --------- Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
1 parent fdbef76 commit 2593655

File tree

3 files changed

+36
-36
lines changed

3 files changed

+36
-36
lines changed

mpi4jax/_src/xla_bridge/mpi_xla_bridge_cpu.pyx

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ cdef void mpi_sendrecv_cpu(void** out_ptr, void** data_ptr) nogil:
198198
)
199199

200200

201-
declare_custom_call_target(b"mpi_allgather", <void*>(mpi_allgather_cpu))
202-
declare_custom_call_target(b"mpi_allreduce", <void*>(mpi_allreduce_cpu))
203-
declare_custom_call_target(b"mpi_alltoall", <void*>(mpi_alltoall_cpu))
204-
declare_custom_call_target(b"mpi_barrier", <void*>(mpi_barrier_cpu))
205-
declare_custom_call_target(b"mpi_bcast", <void*>(mpi_bcast_cpu))
206-
declare_custom_call_target(b"mpi_gather", <void*>(mpi_gather_cpu))
207-
declare_custom_call_target(b"mpi_recv", <void*>(mpi_recv_cpu))
208-
declare_custom_call_target(b"mpi_reduce", <void*>(mpi_reduce_cpu))
209-
declare_custom_call_target(b"mpi_scan", <void*>(mpi_scan_cpu))
210-
declare_custom_call_target(b"mpi_scatter", <void*>(mpi_scatter_cpu))
211-
declare_custom_call_target(b"mpi_send", <void*>(mpi_send_cpu))
212-
declare_custom_call_target(b"mpi_sendrecv", <void*>(mpi_sendrecv_cpu))
201+
declare_custom_call_target("mpi_allgather", <void*>(mpi_allgather_cpu))
202+
declare_custom_call_target("mpi_allreduce", <void*>(mpi_allreduce_cpu))
203+
declare_custom_call_target("mpi_alltoall", <void*>(mpi_alltoall_cpu))
204+
declare_custom_call_target("mpi_barrier", <void*>(mpi_barrier_cpu))
205+
declare_custom_call_target("mpi_bcast", <void*>(mpi_bcast_cpu))
206+
declare_custom_call_target("mpi_gather", <void*>(mpi_gather_cpu))
207+
declare_custom_call_target("mpi_recv", <void*>(mpi_recv_cpu))
208+
declare_custom_call_target("mpi_reduce", <void*>(mpi_reduce_cpu))
209+
declare_custom_call_target("mpi_scan", <void*>(mpi_scan_cpu))
210+
declare_custom_call_target("mpi_scatter", <void*>(mpi_scatter_cpu))
211+
declare_custom_call_target("mpi_send", <void*>(mpi_send_cpu))
212+
declare_custom_call_target("mpi_sendrecv", <void*>(mpi_sendrecv_cpu))

mpi4jax/_src/xla_bridge/mpi_xla_bridge_cuda.pyx

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -715,15 +715,15 @@ cdef void mpi_sendrecv_cuda(cudaStream_t stream, void** buffers,
715715

716716

717717

718-
declare_custom_call_target(b"mpi_allgather", <void*>(mpi_allgather_cuda))
719-
declare_custom_call_target(b"mpi_allreduce", <void*>(mpi_allreduce_cuda))
720-
declare_custom_call_target(b"mpi_alltoall", <void*>(mpi_alltoall_cuda))
721-
declare_custom_call_target(b"mpi_barrier", <void*>(mpi_barrier_cuda))
722-
declare_custom_call_target(b"mpi_bcast", <void*>(mpi_bcast_cuda))
723-
declare_custom_call_target(b"mpi_gather", <void*>(mpi_gather_cuda))
724-
declare_custom_call_target(b"mpi_recv", <void*>(mpi_recv_cuda))
725-
declare_custom_call_target(b"mpi_reduce", <void*>(mpi_reduce_cuda))
726-
declare_custom_call_target(b"mpi_scan", <void*>(mpi_scan_cuda))
727-
declare_custom_call_target(b"mpi_scatter", <void*>(mpi_scatter_cuda))
728-
declare_custom_call_target(b"mpi_send", <void*>(mpi_send_cuda))
729-
declare_custom_call_target(b"mpi_sendrecv", <void*>(mpi_sendrecv_cuda))
718+
declare_custom_call_target("mpi_allgather", <void*>(mpi_allgather_cuda))
719+
declare_custom_call_target("mpi_allreduce", <void*>(mpi_allreduce_cuda))
720+
declare_custom_call_target("mpi_alltoall", <void*>(mpi_alltoall_cuda))
721+
declare_custom_call_target("mpi_barrier", <void*>(mpi_barrier_cuda))
722+
declare_custom_call_target("mpi_bcast", <void*>(mpi_bcast_cuda))
723+
declare_custom_call_target("mpi_gather", <void*>(mpi_gather_cuda))
724+
declare_custom_call_target("mpi_recv", <void*>(mpi_recv_cuda))
725+
declare_custom_call_target("mpi_reduce", <void*>(mpi_reduce_cuda))
726+
declare_custom_call_target("mpi_scan", <void*>(mpi_scan_cuda))
727+
declare_custom_call_target("mpi_scatter", <void*>(mpi_scatter_cuda))
728+
declare_custom_call_target("mpi_send", <void*>(mpi_send_cuda))
729+
declare_custom_call_target("mpi_sendrecv", <void*>(mpi_sendrecv_cuda))

mpi4jax/_src/xla_bridge/mpi_xla_bridge_xpu.pyx

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -702,15 +702,15 @@ cdef void mpi_sendrecv_xpu(void* stream, void** buffers,
702702
checked_sycl_memcpy(xqueue, out_buf, recvbuf , bytes_recv, comm)
703703
free(recvbuf)
704704

705-
declare_custom_call_target(b"mpi_allgather", <void*>(mpi_allgather_xpu))
706-
declare_custom_call_target(b"mpi_allreduce", <void*>(mpi_allreduce_xpu))
707-
declare_custom_call_target(b"mpi_alltoall", <void*>(mpi_alltoall_xpu))
708-
declare_custom_call_target(b"mpi_barrier", <void*>(mpi_barrier_xpu))
709-
declare_custom_call_target(b"mpi_bcast", <void*>(mpi_bcast_xpu))
710-
declare_custom_call_target(b"mpi_gather", <void*>(mpi_gather_xpu))
711-
declare_custom_call_target(b"mpi_recv", <void*>(mpi_recv_xpu))
712-
declare_custom_call_target(b"mpi_reduce", <void*>(mpi_reduce_xpu))
713-
declare_custom_call_target(b"mpi_scan", <void*>(mpi_scan_xpu))
714-
declare_custom_call_target(b"mpi_scatter", <void*>(mpi_scatter_xpu))
715-
declare_custom_call_target(b"mpi_send", <void*>(mpi_send_xpu))
716-
declare_custom_call_target(b"mpi_sendrecv", <void*>(mpi_sendrecv_xpu))
705+
declare_custom_call_target("mpi_allgather", <void*>(mpi_allgather_xpu))
706+
declare_custom_call_target("mpi_allreduce", <void*>(mpi_allreduce_xpu))
707+
declare_custom_call_target("mpi_alltoall", <void*>(mpi_alltoall_xpu))
708+
declare_custom_call_target("mpi_barrier", <void*>(mpi_barrier_xpu))
709+
declare_custom_call_target("mpi_bcast", <void*>(mpi_bcast_xpu))
710+
declare_custom_call_target("mpi_gather", <void*>(mpi_gather_xpu))
711+
declare_custom_call_target("mpi_recv", <void*>(mpi_recv_xpu))
712+
declare_custom_call_target("mpi_reduce", <void*>(mpi_reduce_xpu))
713+
declare_custom_call_target("mpi_scan", <void*>(mpi_scan_xpu))
714+
declare_custom_call_target("mpi_scatter", <void*>(mpi_scatter_xpu))
715+
declare_custom_call_target("mpi_send", <void*>(mpi_send_xpu))
716+
declare_custom_call_target("mpi_sendrecv", <void*>(mpi_sendrecv_xpu))

0 commit comments

Comments
 (0)