
Could mpi4jax be used to effectively interface C++? #265

@Joshuaalbert

Hello, I have a scientific workflow in which I need to interface C++ with JAX. The C++ code needs to pass data to JAX, ideally by writing directly into device array memory so that JAX can use the arrays without making a copy. It would look something like this pseudocode:

# In Python
create_input_device_arrays
create_result_arrays
tell_c++_and_wait
# In C++
fill_input_arrays
tell_python_and_wait
# In Python
operate_on_arrays
set_result_arrays
tell_c++_and_wait
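The turn-taking part of the pseudocode above can be sketched in plain Python to make the protocol concrete. This is a minimal sketch under loud assumptions: the C++ side is replaced by a hypothetical Python thread (`fake_cpp_side`), the "device arrays" are ordinary shared `array("d")` buffers rather than JAX device memory, and the "operate on arrays" step is a simple doubling loop where real code would call into JAX. It shows only the two-event handshake (each side signals the other and then waits). It does not address zero-copy sharing of GPU memory, and it makes no claim about what mpi4jax supports.

```python
import threading
from array import array


def run_handshake(rounds: int, n: int) -> list:
    # create_input_device_arrays / create_result_arrays:
    # plain float64 buffers shared by reference between both sides (no copies).
    inputs = array("d", [0.0] * n)
    results = array("d", [0.0] * n)
    to_cpp = threading.Event()  # Python -> "C++": your turn
    to_py = threading.Event()   # "C++" -> Python: your turn
    seen_by_cpp = []            # results the "C++" side read back, one per round

    def fake_cpp_side():
        # Hypothetical stand-in for the C++ code; it only touches the shared buffers.
        for step in range(rounds):
            to_cpp.wait()
            to_cpp.clear()
            if step > 0:
                seen_by_cpp.append(list(results))  # consume the previous round's results
            for i in range(n):
                inputs[i] = float(step * n + i)    # fill_input_arrays
            to_py.set()                            # tell_python_and_wait
        to_cpp.wait()
        to_cpp.clear()
        seen_by_cpp.append(list(results))          # consume the final results

    worker = threading.Thread(target=fake_cpp_side)
    worker.start()
    to_cpp.set()                  # tell_c++_and_wait
    for _ in range(rounds):
        to_py.wait()
        to_py.clear()
        for i in range(n):        # operate_on_arrays / set_result_arrays
            results[i] = 2.0 * inputs[i]
        to_cpp.set()              # tell_c++_and_wait
    worker.join()
    return seen_by_cpp
```

For example, `run_handshake(2, 3)` runs two rounds over three-element buffers and returns the results the stand-in "C++" side read after each round. Each side clears its event immediately after waking and only signals the other side after that, so a signal is never lost or consumed twice.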

Could mpi4jax handle this flow, or is there another way to accomplish it? I'm looking for solutions that work on both GPU and CPU devices. Any suggestions from the community would be welcome.
