mpi4jax with non static argument? #268
Hello,

    from mpi4py import MPI
    import jax, mpi4jax

    @jax.jit
    def get_pi_part(n_intervals=10000, rank=0, size=1):
        h = 1.0/n_intervals  # width of each interval
        partial_sum = 0.0  # initialize the partial sum
        for i in range(rank+1, n_intervals, size):  # loop over the intervals
            x = h*(i-0.5)  # x value at the center of the interval
            partial_sum += 4.0/(1.0 + x**2)  # add the height of the rectangle to the partial sum
        return h*partial_sum  # return the partial sum

    comm = MPI.COMM_WORLD

    @jax.jit
    def pi_mpi4jax(n_intervals=10000):
        part = get_pi_part(n_intervals, MPI.COMM_WORLD.rank, MPI.COMM_WORLD.size)
        pi = mpi4jax.reduce(part, op=MPI.SUM, root=0)
        return pi[0]

    pi_result = pi_mpi4jax()

However, I obtain the following error:
If I call get_pi_part() directly, without going through pi_mpi4jax(), it works well. But of course, I would like to use mpi4jax to scale my code across several nodes. The same script works well when adapted to mpi4py or numba_mpi, for example. Any idea of good practice in such cases, where I want to pass dynamic arguments? The JAX documentation is quite hard to follow... Thank you!
Replies: 1 comment
I think the problem in your script has to do with the fact that n_intervals must be static (because you have a for loop depending on it). When you call get_pi_part() directly you are saved because you do not specify it, and the default value is then effectively treated as static, but if you were to specify it you'd see that your code would break. This is not related to mpi4jax.
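To make this concrete, here is a minimal sketch of one way to mark those arguments as static, assuming a JAX version that supports `static_argnames` in `jax.jit`. It is not the original script: the Python loop is replaced by a `jnp.arange` over the interval indices, the upper bound is `n_intervals + 1` so that the last interval is included, and the MPI reduction is left out (the per-rank parts are simply added up in plain Python to mimic it).

```python
from functools import partial

import jax
import jax.numpy as jnp


# n_intervals, rank and size determine the shape of the index array, so
# they have to be concrete Python ints at trace time: mark them static.
@partial(jax.jit, static_argnames=("n_intervals", "rank", "size"))
def get_pi_part(n_intervals=10000, rank=0, size=1):
    h = 1.0 / n_intervals  # width of each interval
    # Indices of the intervals handled by this rank (assumption: the
    # upper bound n_intervals + 1 includes the last interval).
    i = jnp.arange(rank + 1, n_intervals + 1, size)
    x = h * (i - 0.5)  # midpoint of each interval
    return h * jnp.sum(4.0 / (1.0 + x**2))


# Stand-in for the MPI reduction: add the parts of 4 hypothetical ranks.
size = 4
pi_estimate = sum(float(get_pi_part(10000, r, size)) for r in range(size))
print(pi_estimate)
```

Keep in mind that every new combination of static values triggers a recompilation, so this works best when n_intervals, rank and size change rarely, which is the case for rank and size within a single MPI run.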