Hello, I have a scientific workflow in which I need to interface C++ with JAX. The C++ code needs to pass data to JAX, ideally by writing directly into device array memory so that JAX can use the arrays without performing a copy. It would look something like this pseudocode:
```
# in Python
create_input_device_arrays
create_result_arrays
tell_c++_and_wait
# in C++
fill_input_arrays
tell_python_and_wait
# in Python
operate_on_arrays
set_result_arrays
tell_c++_and_wait
```
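To make the turn-taking in the pseudocode concrete, here is a minimal single-process sketch of the handshake. It rests on loud assumptions: a Python thread stands in for the C++ side, `threading.Event` objects stand in for whatever cross-language signalling would really be used, and plain host `array("d")` buffers stand in for JAX device arrays. The names `run_workflow`, `cpp_side`, and `python_side` are hypothetical. It illustrates only the ordering (fill in place, signal, compute in place, signal back), not zero-copy interop with JAX or anything mpi4jax provides.

```python
import threading
from array import array


def cpp_side(to_cpp, to_python, inputs, results, n_rounds, seen):
    """Hypothetical stand-in for the C++ process (a Python thread here)."""
    for r in range(n_rounds):
        to_cpp.wait()            # wait until Python says "your turn"
        to_cpp.clear()
        if r > 0:
            seen.append(list(results))  # read the previous round's results
        for i in range(len(inputs)):
            inputs[i] = float(r + i)    # fill_input_arrays, written in place
        to_python.set()          # tell_python_and_wait
    to_cpp.wait()                # final hand-off: read the last results
    to_cpp.clear()
    seen.append(list(results))


def python_side(to_cpp, to_python, inputs, results, n_rounds):
    for _ in range(n_rounds):
        to_cpp.set()             # tell_c++_and_wait
        to_python.wait()
        to_python.clear()
        for i in range(len(inputs)):
            # operate_on_arrays / set_result_arrays, written in place
            results[i] = 2.0 * inputs[i]
    to_cpp.set()                 # hand the final results back to "C++"


def run_workflow(n_rounds, size):
    # create_input_device_arrays / create_result_arrays (host stand-ins).
    inputs = array("d", [0.0] * size)
    results = array("d", [0.0] * size)
    to_cpp, to_python = threading.Event(), threading.Event()
    seen = []
    worker = threading.Thread(
        target=cpp_side, args=(to_cpp, to_python, inputs, results, n_rounds, seen)
    )
    worker.start()
    python_side(to_cpp, to_python, inputs, results, n_rounds)
    worker.join()
    return seen
```

For example, `run_workflow(3, 4)` returns the results the "C++" side reads after each round: `[[0.0, 2.0, 4.0, 6.0], [2.0, 4.0, 6.0, 8.0], [4.0, 6.0, 8.0, 10.0]]`. Each event is set by only one side and cleared by the other before the reply is sent, so no signal is lost. With real JAX arrays, the Python side would likely also need to call `block_until_ready()` on its results before signalling, because JAX dispatches work asynchronously.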
Could mpi4jax handle this flow, or is there another way to accomplish it? I am looking for solutions that work on both GPU and CPU devices. Any suggestions from the community would be welcome.
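For the CPU half of the question, "setting the memory directly without a copy" comes down to both sides referring to the same block of memory through a raw pointer. The sketch below shows that aliasing on the host using only `ctypes`; it is an illustration under assumptions, not a JAX or mpi4jax recipe. `cpp_fill` and `demo` are hypothetical: `cpp_fill` is a Python stand-in for a C++ routine that writes through a `double*`, and the `ctypes` array stands in for memory C++ would own. Getting JAX itself to adopt such a buffer without copying (DLPack via `jax.dlpack` is one possible route) and the GPU case are not covered here.

```python
import ctypes


def cpp_fill(addr, n):
    """Hypothetical stand-in for a C++ routine writing through a raw double*."""
    buf = (ctypes.c_double * n).from_address(addr)
    for i in range(n):
        buf[i] = 0.5 * i


def demo(n=4):
    # Memory "owned by C++": here simply a ctypes array kept alive by this frame.
    backing = (ctypes.c_double * n)()
    addr = ctypes.addressof(backing)

    # Python wraps the same address. from_address does not copy, so the
    # wrapper and the backing array alias one block of memory.
    alias = (ctypes.c_double * n).from_address(addr)

    cpp_fill(addr, n)  # writes land directly in the shared memory
    return list(alias), list(backing)
```

Calling `demo()` returns `([0.0, 0.5, 1.0, 1.5], [0.0, 0.5, 1.0, 1.5])`: the values written through the raw address are visible through both objects because no copy was made. The caller must keep the owning buffer alive for as long as any alias is in use.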