Hybrid MPI program with JAX #262

Closed Answered by PhilipVinc
mtagliazucchi asked this question in Q&A
Do you think this is true even if the function applied to the data chunks is very expensive and uses jax.jit and jax.vmap?

In my experience, yes. The multithreading inside jax.jit performs worse than the MPI model. Of course, this only holds if your algorithm scales well.
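
For readers who want something concrete to benchmark, here is a minimal sketch of the MPI-style pattern discussed above: each rank takes one chunk of the data, runs a jitted and vmapped function on it, and the partial results are combined at the end. The names `chunk_bounds`, `expensive_fn` and `process_chunk` are hypothetical and the data is synthetic. The final reduction uses plain mpi4py rather than mpi4jax to keep the example short; for communication inside jitted code, mpi4jax's collectives would take its place. If mpi4py is not installed, the script falls back to a single process.

```python
import jax
import jax.numpy as jnp


def chunk_bounds(n_items, n_ranks, rank):
    """Return the [start, stop) range of n_items owned by `rank`."""
    base, extra = divmod(n_items, n_ranks)
    # The first `extra` ranks each get one additional item.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return start, stop


def expensive_fn(x):
    # Hypothetical stand-in for the costly per-sample computation.
    return jnp.sum(jnp.sin(x) ** 2)


# vmap over the samples in a chunk, then jit the whole per-chunk computation.
process_chunk = jax.jit(jax.vmap(expensive_fn))


def main():
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        rank, size = comm.Get_rank(), comm.Get_size()
    except ImportError:
        # Assumption for this sketch: run serially when MPI is unavailable.
        comm, rank, size = None, 0, 1

    n_samples, n_features = 1000, 16
    # Every rank builds the same synthetic data and slices out its own chunk.
    data = jnp.linspace(0.0, 1.0, n_samples * n_features)
    data = data.reshape(n_samples, n_features)
    start, stop = chunk_bounds(n_samples, size, rank)

    local = process_chunk(data[start:stop])  # shape: (stop - start,)
    local_sum = float(local.sum())

    # Combine the per-rank partial sums outside of jit.
    total = local_sum if comm is None else comm.allreduce(local_sum, op=MPI.SUM)
    if rank == 0:
        print(f"{size} rank(s), total = {total:.6f}")
    return total


if __name__ == "__main__":
    main()
```

Launched with something like `mpirun -n 4 python chunks.py` (the file name is arbitrary), the printed total should agree with the single-process run up to floating-point rounding. Timing both variants on your own function is the only reliable way to see which scales better.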

I tried to do this using MPI "map-by" and "bind-to" options, but it didn't work. Do you have any examples of how to do this?

I never do it by hand because, on the clusters I work on, Slurm does it automatically.
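
Whichever launcher does the binding, it can help to check what each process actually received. The sketch below is a small diagnostic, not a binding recipe: it prints the CPU ids each rank is allowed to run on. It assumes Linux (`os.sched_getaffinity` is not available elsewhere) and reads the rank from environment variables that Slurm, Open MPI or PMI-based launchers commonly set. The helper names `allowed_cpus` and `describe_binding` are hypothetical.

```python
import os


def allowed_cpus():
    """Return the sorted list of CPU ids this process may run on."""
    if hasattr(os, "sched_getaffinity"):  # Linux only
        return sorted(os.sched_getaffinity(0))
    # Fallback for platforms without affinity support: report all CPUs.
    return list(range(os.cpu_count() or 1))


def describe_binding():
    # Which variable is present depends on the launcher; default to rank 0.
    rank = (
        os.environ.get("SLURM_PROCID")
        or os.environ.get("OMPI_COMM_WORLD_RANK")
        or os.environ.get("PMI_RANK")
        or "0"
    )
    cpus = allowed_cpus()
    return f"rank {rank}: {len(cpus)} cpu(s) -> {cpus}"


if __name__ == "__main__":
    print(describe_binding())
```

If binding works, each rank should report its own, non-overlapping set of cores. If every rank lists all cores of the node, the "map-by" and "bind-to" options did not take effect, and the ranks' JAX thread pools may compete for the same cores.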

Finally, for the single program multiple data problem, do you think that jax.sharding and mpi4jax generally have similar performance or not? Which of the two is more efficient?

On GPUs, in my code…

Answer selected by dionhaefner