Hybrid MPI program with JAX #262
-
Hi, I'm using […]. However, I'm not sure how to launch the […].
-
If you are using CPUs, my traditional suggestion is to assume that jax effectively uses only 1 or 2 CPUs, and therefore to launch with […]. You should also use taskset to ensure that the kernel does not move the processes around, unless that is already handled by your HPC center.
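A minimal sketch of the launch pattern suggested above, assuming Open MPI (the `OMPI_COMM_WORLD_LOCAL_RANK` variable is Open MPI-specific; Slurm and MPICH expose different ones) and a placeholder script name:

```shell
# Launch 4 MPI ranks, each pinned with taskset to its own core so the
# kernel cannot migrate the process. script.py is a placeholder for your
# actual program; adjust the core mapping for multi-socket nodes.
mpirun -np 4 bash -c 'taskset -c $OMPI_COMM_WORLD_LOCAL_RANK python script.py'
```

This pins local rank i to core i. If your scheduler (e.g. Slurm with `--cpu-bind`) already pins tasks, the taskset wrapper is redundant.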
-
Thanks for the reply! I still have a few questions.
Do you think this is true even if the function applied to the data chunks is very expensive and uses […]?
I tried to do this using MPI's `--map-by` and `--bind-to` options, but it didn't work. Do you have any examples of how to do this? Finally, for the single-program-multiple-data problem, do you think that […]? Thanks again!
-
In my experience, yes. The multithreading inside jax's jit performs less well than the MPI model. Of course, this is only true if your algorithm scales well.
I never do it myself, because on the clusters I work on it's handled automatically by slurm.
On GPUs, in my codes, sharding outperforms mpi4jax's direct GPU communication by about 10%. I have no idea why; it's probably because NCCL is better tuned than MPI. Still, it's what we see.

On CPUs, jax may use MPI or GLOO as a communication backend for sharding. MPI as a sharding backend works, but it's largely undocumented and hard to set up. I am partly responsible for getting the support there, so we have some brief docs here: https://netket.readthedocs.io/en/latest/docs/parallelization.html#mpitrampoline-backend-very-experimental . You need to compile MPITrampoline and launch jax through it. It's a mess, but performance is identical to using mpi4jax, and you get the niceties of sharding (modulo some unsupported operations).

So in essence mpi4jax performs identically to the MPI sharding backend, but is easier to set up and write code for. Sometimes sharding breaks down and starts to replicate the calculations unless you use shard_map; mpi4jax is equivalent to putting a shard_map around the whole of your code, so it's a bit more 'reliable'.
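To make the shard_map point concrete, here is a minimal sketch of the MPI-like programming model it gives you. This is an illustration, not code from the thread: the XLA flag that splits one CPU into several fake "devices" is only for demonstration, and the function names are placeholders.

```python
import os
# Demonstration only: make XLA expose 4 "devices" on a single CPU.
# Must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map          # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("i",))

@jax.jit
def normalized(x):
    # The body runs once per device shard, much like one MPI rank each;
    # psum over axis "i" is the collective (an allreduce).
    def body(a):
        total = jax.lax.psum(jnp.sum(a), axis_name="i")
        return a / total
    return shard_map(body, mesh=mesh, in_specs=P("i"), out_specs=P("i"))(x)

x = jnp.arange(1.0, 9.0)   # sums to 36
print(normalized(x))       # each element divided by the global sum
```

Wrapping the whole computation in shard_map like this is, as noted above, essentially what mpi4jax gives you implicitly; without it, the partitioner is free to replicate work across devices.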
-
Thanks so much!