Fixes for cuda pjrt plugin #241

Merged
merged 5 commits into from
Jun 12, 2024

Conversation

PhilipVinc
Member

Fixes #235

See the discussion in jax-ml/jax#21807

In short, the PJRT client only accepts standard strings (str), but our Cython code returns a bytes object.

This will be fixed in a future jax version, but in the meantime, we can ship the fix.
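
As a rough illustration of this kind of fix, the sketch below normalizes target names to `str` before they would be handed to the registration call. The helper names (`as_str`, `normalize_targets`) and the sample registry are hypothetical and not taken from the mpi4jax source; the actual registration call is deliberately left out.

```python
def as_str(name, encoding="utf-8"):
    """Return `name` as `str`, decoding it first if it is bytes-like."""
    if isinstance(name, (bytes, bytearray)):
        return bytes(name).decode(encoding)
    return name


def normalize_targets(targets):
    """Return a copy of a {name: capsule} mapping with every key as `str`."""
    return {as_str(name): capsule for name, capsule in targets.items()}


# Hypothetical registry, shaped like what a Cython extension might return:
# one key is bytes, the other is already a str.
raw_targets = {b"mpi_allreduce": object(), "mpi_barrier": object()}
targets = normalize_targets(raw_targets)
print(sorted(targets))
```

Normalizing at the boundary keeps the Cython side unchanged and leaves keys that are already `str` untouched.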

@PhilipVinc
Member Author

Now that this works, we could even do as JAX does and dlopen xla_bridge_cuda only when needed (essentially when a user binds a CUDA primitive).

By dlopening it, we could get around the linking problem, as we could have Nvidia wheels as a compile-time dependency for the headers, but then only load the .so when needed. Then, we could do as jax does and set the rpath to the parent folder of the package, which we know should contain Nvidia headers.

This would basically solve the CUDA linking problem and make it simple to install mpi4jax+cuda without having to use --no-build-isolation.

To support installations with 'local' CUDA, the dlopen mechanism would automatically pick up LD_LIBRARY_PATH if nothing is in the RPATH (again, this is the scheme used by JAX). As the user must already set LD_LIBRARY_PATH, this would work out well.
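
The deferred-loading idea can be sketched in Python with `ctypes`. This is only an illustration under stated assumptions: the `LazyLibrary` class, the injectable `loader` argument, and the file name `libexample.so` are hypothetical, and RPATH handling (which happens at link time) is not modelled. Passing a bare file name to `ctypes.CDLL` leaves the lookup to the dynamic loader, whose search on Linux includes `LD_LIBRARY_PATH`.

```python
import ctypes
import threading


class LazyLibrary:
    """Load a shared library on first use instead of at import time."""

    def __init__(self, soname, loader=ctypes.CDLL):
        self._soname = soname      # bare name: the dynamic loader searches for it
        self._loader = loader      # injectable so the sketch runs without CUDA
        self._handle = None
        self._lock = threading.Lock()

    @property
    def loaded(self):
        return self._handle is not None

    def get(self):
        """Return the library handle, loading it on the first call only."""
        with self._lock:
            if self._handle is None:
                self._handle = self._loader(self._soname)
            return self._handle


# Demonstration with a fake loader, so nothing real has to be installed.
calls = []


def fake_loader(name):
    calls.append(name)
    return object()


lib = LazyLibrary("libexample.so", loader=fake_loader)
assert not lib.loaded     # nothing is loaded at construction time
first = lib.get()         # first use triggers the load
second = lib.get()        # later uses reuse the cached handle
```

With the default loader, the library would only need to be present on machines where a CUDA primitive is actually used.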

One possible way to ship this would be to mimic JAX's extras, and have mpi4jax be mpi4jax[cpu], and then also have mpi4jax[cuda] and mpi4jax[rocm] and mpi4jax[sycl]...

I am unsure, but finding a way to simplify the installation would surely be great.

@dionhaefner
Collaborator

That sounds doable with some work, but is this really that complicated?

pip install jax[cuda] mpi4py cython
pip install mpi4jax --no-build-isolation

@PhilipVinc
Member Author

PhilipVinc commented Jun 12, 2024

Well, the problem is that you cannot ship a requirements file; instead you must share a script to set up the environment.
This somewhat breaks 'simple to set up' reproducibility.

It would be better if we could just do pip install jax[cuda] mpi4jax[cuda].

I agree that it's a lot of work for not much, though. So it's just something that will stay in the backlog until who knows when...

@PhilipVinc PhilipVinc requested a review from dionhaefner June 12, 2024 09:49
@dionhaefner
Collaborator

@PhilipVinc does that work?

@PhilipVinc
Member Author

Yes. Tested and it works.

@PhilipVinc PhilipVinc merged commit 2593655 into master Jun 12, 2024
@PhilipVinc PhilipVinc deleted the pv/fixes branch June 12, 2024 15:04
@dionhaefner
Collaborator

> Well, the problem is that you cannot ship a requirements file; instead you must share a script to set up the environment. This somewhat breaks 'simple to set up' reproducibility.
>
> It would be better if we could just do pip install jax[cuda] mpi4jax[cuda].
>
> I agree that it's a lot of work for not much, though. So it's just something that will stay in the backlog until who knows when...

Actually, there may be another option: depend on the CUDA runtime both at build time and at runtime. That way you can link to the libraries during the build, set the rpath to site-packages, and rely on them being present (since they're also a runtime dependency). Not the prettiest solution, but this should work out of the box.
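
To make the rpath idea concrete, here is a small sketch that computes `$ORIGIN`-relative rpath entries pointing from a compiled extension back up to site-packages. It assumes the Nvidia wheels install their shared libraries under `site-packages/nvidia/<component>/lib`; the helper name `nvidia_rpaths` and the extension module name used below are illustrative, not confirmed against the mpi4jax build.

```python
import posixpath


def nvidia_rpaths(extension_name, components):
    """Build $ORIGIN-relative rpath entries for Nvidia wheel library dirs.

    `extension_name` is the dotted module name of the compiled extension.
    Its shared object lives one directory per package level below
    site-packages, so the number of dots is the number of levels to climb.
    """
    levels_up = extension_name.count(".")
    to_site_packages = [".."] * levels_up
    return [
        posixpath.join("$ORIGIN", *to_site_packages, "nvidia", component, "lib")
        for component in components
    ]


# Assumed module name and wheel layout, for illustration only.
rpaths = nvidia_rpaths(
    "mpi4jax._src.xla_bridge.mpi_xla_bridge_cuda",
    ["cuda_runtime"],
)
print(rpaths)
```

Each entry could then be passed to the linker (for example through an extension's runtime library directories), where `$ORIGIN` is expanded by the dynamic loader to the directory containing the shared object.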

@PhilipVinc
Member Author

Yeah, but this must be something the user opts in to. Ideally it should be exposed as an extra. But build-time dependencies can only be declared before the extras requirements are processed.

The only solution I see is to do as JAX does, and have a second package with the correct build-time dependencies, which is required only if the extra is specified.

But that's quite a bit of extra infrastructure.
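
A minimal sketch of that two-package layout, assuming hypothetical plugin package names (`mpi4jax-cuda-plugin`, `mpi4jax-rocm-plugin`) and an illustrative base requirement list. None of these names come from the thread or from an existing release. The idea is that the main package stays free of CUDA build-time dependencies, and each extra only pulls in a separate package that carries them.

```python
# Illustrative base requirements for the main (CPU-only) package.
BASE_REQUIRES = ["jax", "mpi4py"]

# Each extra maps to a separate, hypothetical plugin package that would
# declare its own build-time dependencies (e.g. the Nvidia wheels).
EXTRAS_REQUIRE = {
    "cuda": ["mpi4jax-cuda-plugin"],
    "rocm": ["mpi4jax-rocm-plugin"],
}


def resolve_requirements(extras=()):
    """Return the requirement list an installer would see for `extras`."""
    unknown = sorted(set(extras) - set(EXTRAS_REQUIRE))
    if unknown:
        raise ValueError(f"unknown extras: {unknown}")
    resolved = list(BASE_REQUIRES)
    for extra in extras:
        resolved.extend(EXTRAS_REQUIRE[extra])
    return resolved


print(resolve_requirements())          # plain install
print(resolve_requirements(["cuda"]))  # install with the cuda extra
```

In a real setup the mapping would live in the package metadata rather than in a helper function; the function here only makes the opt-in behaviour visible.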

Development

Successfully merging this pull request may close these issues.

Mpi4jax + jax[cuda12]
2 participants