Fixes for cuda pjrt plugin #241
Now that this works, we could even do like JAX does, and By dlopening it, we could get around the linking problem, as we could have the Nvidia wheels as a compile-time dependency for the headers, but then only load the .so when needed. Then, we could do as JAX does and set the rpath to the parent folder of the package, which we know should contain Nvidia headers. This would basically solve the CUDA linking problem and make it as simple to install mpi4jax+cuda without having to use To support installations with 'local' CUDA, the dlopen mechanism would automatically pick up One possible way to ship this would be to mimic JAX's extras, and have I am unsure, but surely finding a way to simplify the installation would be great.
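A minimal sketch of the load-on-demand idea described above, using Python's `ctypes` as a stand-in for dlopen. The helper names `candidate_paths` and `load_first_available`, and the idea of passing an explicit list of search directories, are assumptions made for illustration; they are not taken from mpi4jax or JAX.

```python
import ctypes
import os


def candidate_paths(libname, search_dirs):
    """Build the list of locations to try, in priority order.

    Explicit directories (for example, folders inside site-packages where
    pip-installed wheels place their shared libraries) come first. The bare
    library name comes last, so the system loader can still find a locally
    installed copy through its usual search path.
    """
    paths = [os.path.join(d, libname) for d in search_dirs]
    paths.append(libname)
    return paths


def load_first_available(libname, search_dirs):
    """Return a ctypes.CDLL handle for the first candidate that loads, else None."""
    for path in candidate_paths(libname, search_dirs):
        try:
            return ctypes.CDLL(path)
        except OSError:
            # Not present (or not loadable) at this location: try the next one.
            continue
    return None
```

With this shape, nothing is linked at import time: a caller would only invoke `load_first_available` when the CUDA code path is actually requested, and could fall back to a CPU-only path when it returns `None`.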
That sounds doable with some work, but is this really that complicated?
pip install jax[cuda] mpi4py cython
pip install mpi4jax --no-build-isolation
Well, the problem is that you cannot ship a requirements file; instead, you must share a script to set up the environment. It would be better if we could just do I agree that it's a lot of work for not much, though. So it's just something that will stay in the backlog until who knows when...
@PhilipVinc does that work?
Yes. Tested and it works.
Actually, there may be another option: depend on the CUDA runtime both at build time and at runtime. That way you can link against its libraries during the build, set the rpath to site-packages, and rely on them being present (since they're also a runtime dependency). Not the prettiest solution, but this should work out of the box.
Yeah, but this must be something the user opts in to. Ideally it should be exposed as an extra, but build-time dependencies can only be declared before the extras are processed. The only solution I see is to do like JAX and have a second package with the correct build-time dependencies, which is required only if the extra is specified. But that's quite a bit of extra infrastructure.
Fixes #235
See the discussion in jax-ml/jax#21807
In short, the pjrt client only supports standard strings, but our Cython code returns a bytes string.
This will be fixed in a future jax version, but in the meantime, we can ship the fix.
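As an illustration of the kind of workaround described above, here is a minimal sketch that normalizes names to `str` before they are handed to an API that only accepts standard strings. The helpers `as_text` and `normalize_keys` are hypothetical and are not the actual change in this PR; UTF-8 is assumed as the encoding.

```python
def as_text(value, encoding="utf-8"):
    """Return `value` as a str, decoding it first if it is a bytes object."""
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).decode(encoding)
    return value


def normalize_keys(mapping):
    """Return a new dict whose keys are all str; values are left untouched."""
    return {as_text(key): val for key, val in mapping.items()}
```

Values that are already `str` pass through unchanged, so the conversion stays harmless once the upstream fix lands.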