Replies: 2 comments 11 replies
-
I may be missing some complexities here, but this sounds like it could be done a lot simpler with …
-
Thanks for the help! As you said, for my case …
-
Hello,
First of all, I don't have much experience with MPI, so I would like to apologize in advance for my lack of jargon and knowledge.
I am working on an optimization code written in JAX. We want to scale up the optimization problems we solve, but the current state of the code only allows us to use a single GPU/CPU, so we are limited by GPU memory during the Jacobian calculation. I managed to distribute objectives to different GPUs, and we are now able to run bigger problems. However, my approach was to run different functions that have pre-distributed data on different devices via `jax.jit` with the `device` argument, inside a for loop. I want to make that for loop parallel using `mpi4jax`.

Here is a very simplified minimal working example of what I am trying to do. Sorry for the lengthy code, but I would like to keep the main structure of the code because of a bunch of other non-MPI-related design choices. The purpose of my question is specifically the `jac_error` method of the `ObjectiveFunctionParallel` class.

mpi-parallel-test-case.pdf
This is the Jupyter notebook version:
mpi-parallel-test-case.zip
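To make the current (sequential) approach concrete, here is a minimal sketch of the pattern described above, not the attached notebook: each objective holds data pre-placed on its own device, and its jitted compute function runs on that device. The `Objective` class, its toy residual, and the data shapes here are all made up for illustration; since newer JAX versions deprecate `jax.jit`'s `device` argument, this sketch places data with `jax.device_put` and lets the jitted call follow its committed inputs instead.

```python
import jax
import jax.numpy as jnp

devices = jax.devices()  # one entry per available device (CPU-only is fine)

class Objective:
    """Hypothetical stand-in for one of the real Objective classes."""

    def __init__(self, data, device):
        self.device = device
        # pre-place this objective's data on its assigned device
        self.data = jax.device_put(data, device)
        # jit the compute function; it runs where its committed inputs live
        self._compute = jax.jit(self._raw_compute)

    def _raw_compute(self, x):
        # toy residual: difference between model x and stored data
        return x - self.data

    def compute(self, x):
        # commit x to this objective's device so the jitted call runs there
        return self._compute(jax.device_put(x, self.device))

objectives = [
    Objective(jnp.full(3, float(i)), devices[i % len(devices)])
    for i in range(4)
]

x = jnp.ones(3)
# the sequential for loop the question wants to parallelize with MPI
results = [obj.compute(x) for obj in objectives]
```

With a single device every objective lands on the same hardware, but the bookkeeping is identical to the multi-GPU case.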
To give a short overview of what the code is doing, we have:

- `Objective` classes (in the actual code we have over 60 of them, with different output dimensions)
- an `ObjectiveFunction` class that has wrappers for the Jacobian and compute functions of each objective

The problem is that, for large-scale problems, we get an out-of-memory error even if we form the Jacobian with a bunch of single `jvp`s (without `vmap`ping over them).

What I have done so far:
- an `ObjectiveFunctionParallel` class that calls each `Objective`'s compute and Jacobian methods, placing the data on the required GPU
- `jax.jit`ting the methods depending on the `Objective`'s device id

I would like to open an MPI communication only when I need to take the Jacobian (otherwise I would need to put many `if rank == 0:` conditions, which is not feasible at the size of our code), execute the part where I currently have a for loop in parallel, and then close the MPI communication. All the data needed by each GPU is already stored on that GPU, so the only data transfer needed is at the end, to form the full Jacobian. We typically need to take the Jacobian around 200 times per optimization, so I need to be able to open and close the communication to prevent everything in the code from being computed multiple times on multiple processes.

I think this is possible with `mpi4py` using Dynamic Process Management, but I wanted to ask whether it is doable in `mpi4jax`? Also, I would appreciate any feedback on the implementation I have in mind! I am sorry for not being able to give a shorter explanation and MWE.
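For reference, the two pieces described above, forming a per-objective Jacobian block from single `jvp`s and stacking the blocks into the full Jacobian, can be sketched as follows. Everything here is hypothetical (`make_objective`, the sizes, the residual functions); the loop over objectives runs sequentially, and the comments mark where, in the MPI version, each rank would compute only its own blocks and the stacking would happen after a gather to one rank (e.g. `mpi4py`'s `comm.gather`).

```python
import jax
import jax.numpy as jnp

def make_objective(scale):
    # toy per-objective residual function (hypothetical)
    def f(x):
        return scale * jnp.sin(x)
    return f

x = jnp.arange(3.0)
n = x.shape[0]
objectives = [make_objective(s) for s in (1.0, 2.0)]

def jac_columns(f, x):
    # one jvp per basis vector gives one Jacobian column at a time,
    # so only a single column's intermediates are live at once
    cols = []
    for i in range(n):
        e = jnp.zeros(n).at[i].set(1.0)
        _, col = jax.jvp(f, (x,), (e,))
        cols.append(col)
    return jnp.stack(cols, axis=1)  # shape (out_dim, n)

# Sequential here; under MPI each rank would handle a subset of
# objectives, and the blocks would be gathered before stacking.
blocks = [jac_columns(f, x) for f in objectives]
full_jac = jnp.vstack(blocks)  # shape (sum of out_dims, n)
```

Each block agrees with `jax.jacfwd` of the same function; the column-at-a-time loop just trades compute for the memory that a `vmap`ped forward pass would hold.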
Best regards,