Compile with pip Nvidia packages... #236

Merged
merged 4 commits into from
Jun 7, 2024

Conversation

PhilipVinc
Member

No description provided.

codecov bot commented Jun 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 71.45%. Comparing base (0e70789) to head (55e8469).
Report is 4 commits behind head on master.

Current head 55e8469 differs from pull request most recent head f5bb285

Please upload reports for the commit f5bb285 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #236      +/-   ##
==========================================
+ Coverage   70.21%   71.45%   +1.24%     
==========================================
  Files          34       34              
  Lines        2078     2172      +94     
  Branches      157      164       +7     
==========================================
+ Hits         1459     1552      +93     
+ Misses        559      558       -1     
- Partials       60       62       +2     


@PhilipVinc
Member Author

Installing this branch with

pip install --no-build-isolation --verbose git+https://github.com/mpi4jax/mpi4jax.git@pv/test-install

works and correctly picks up the Nvidia packages if they are installed.

However, one then needs to set LD_LIBRARY_PATH, because the library path is not hardcoded, with something like

export LD_LIBRARY_PATH=[...]/ENV_NAME/lib/python3.11/site-packages/nvidia/cuda_runtime/lib:$LD_LIBRARY_PATH
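For reference, the path in the export above can also be located programmatically. This is only a rough sketch: the helper names `find_pip_cuda_runtime_lib` and `prepend_to_ld_library_path` are made up for illustration, and it assumes the pip package keeps the `nvidia/cuda_runtime/lib` layout shown in the command. It is not what this branch does internally.

```python
import os
import sys
from pathlib import Path


def find_pip_cuda_runtime_lib(search_paths=None):
    """Return the first <entry>/nvidia/cuda_runtime/lib directory found, or None.

    Hypothetical helper: scans sys.path (or the given paths) for the layout
    used in the export command above.
    """
    entries = sys.path if search_paths is None else search_paths
    for entry in entries:
        candidate = Path(entry) / "nvidia" / "cuda_runtime" / "lib"
        if candidate.is_dir():
            return candidate
    return None


def prepend_to_ld_library_path(lib_dir, env=None):
    """Return a copy of the environment with lib_dir first in LD_LIBRARY_PATH."""
    env = dict(os.environ if env is None else env)
    current = env.get("LD_LIBRARY_PATH", "")
    env["LD_LIBRARY_PATH"] = (
        f"{lib_dir}{os.pathsep}{current}" if current else str(lib_dir)
    )
    return env


if __name__ == "__main__":
    lib_dir = find_pip_cuda_runtime_lib()
    if lib_dir is None:
        print("no pip-installed nvidia/cuda_runtime/lib found on sys.path")
    else:
        print(f"export LD_LIBRARY_PATH={lib_dir}:$LD_LIBRARY_PATH")
```

Note that the dynamic loader reads LD_LIBRARY_PATH when a process starts, so the returned environment is meant to be passed to a new process (for example via `subprocess.run(..., env=env)`) or exported in the shell before launching Python.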

Nevertheless, this still gives rise to some annoying ptxas errors.

@PhilipVinc
Member Author

Though it works for standard code

@PhilipVinc
Member Author

But this is nice. It avoids requiring CUDA to be installed.

@PhilipVinc
Member Author

This now sets the rpath, so the compile-time library paths are baked into the build when it detects that CUDA was obtained from pip.
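To illustrate what setting the rpath means in practice, here is a minimal sketch that builds the usual GCC/Clang linker flags for a given library directory. The function name `rpath_link_args` is hypothetical, and this is not a copy of the logic in this PR's setup.py; it only shows the general technique.

```python
from pathlib import Path


def rpath_link_args(lib_dir):
    """Build linker flags that search lib_dir at link time and at run time.

    -L<dir>          lets the linker find the library while building.
    -Wl,-rpath,<dir> records <dir> in the built extension, so the dynamic
                     loader looks there without needing LD_LIBRARY_PATH.
    """
    lib_dir = str(Path(lib_dir))
    return [f"-L{lib_dir}", f"-Wl,-rpath,{lib_dir}"]


if __name__ == "__main__":
    # Illustrative directory only; a real build would use the detected path.
    print(rpath_link_args("/opt/example/nvidia/cuda_runtime/lib"))
```

A list like this could be passed as `extra_link_args` to a setuptools `Extension`; setuptools also offers `runtime_library_dirs` for the same purpose. The trade-off is that the recorded path is fixed at build time, so moving or replacing the CUDA libraries afterwards can break the extension.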

@PhilipVinc
Member Author

@dionhaefner what can we do to package this properly?

@dionhaefner
Collaborator

This looks generally fine to me.

Setting rpath means that mpi4jax may break violently if the user ever updates the CUDA libs but I guess it's the best we can do, barring using dlopen at runtime like XLA does.
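For comparison, the runtime-loading alternative can be sketched with Python's `ctypes`, which wraps `dlopen` on Linux. This is only an illustration of the idea, not how XLA implements it, and `load_first_available` is a hypothetical helper name.

```python
import ctypes


def load_first_available(candidates):
    """Try each library path or name in order; return the first that loads.

    Returns a ctypes.CDLL handle, or None if none of the candidates load.
    """
    for path in candidates:
        try:
            return ctypes.CDLL(path)
        except OSError:
            # Library missing or not loadable here; try the next candidate.
            continue
    return None


if __name__ == "__main__":
    # The names below are illustrative; real code would list the library
    # versions it actually supports.
    handle = load_first_available(["libcudart.so.12", "libcudart.so"])
    print("CUDA runtime found" if handle is not None else "CUDA runtime not found")
```

Because the library is resolved when the program runs rather than when it is built, an updated CUDA installation is picked up automatically, at the cost of handling a missing library in code.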

I think the most important task now is to test all modes of installation on CI. That should give us enough confidence to ship this new feature pretty much as-is.

Wip

WIP

WIP

WIP

WIP

fix

a

a

a

rpath

cleanup

documentation

nit

add test build

fix

k

123

dsa

asd

asd

as

f

f
@PhilipVinc
Member Author

@dionhaefner I added a test for the build infrastructure using PyPI-distributed CUDA.

I also added the relevant instructions and a mention that [cuda12] is not yet supported.

Anything else?

@PhilipVinc PhilipVinc marked this pull request as ready for review June 6, 2024 07:46
@PhilipVinc PhilipVinc requested a review from dionhaefner June 6, 2024 08:10
@PhilipVinc
Member Author

Is it fine for you if I merge and tag a new release?

There are 3 ways to install jax with CUDA support:
- using a pypi-distributed CUDA installation (suggested by jax developers) ``pip install -U 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html``
- using the locally-installed CUDA version, which must be compatible with jax. ``pip install -U 'jax[cuda12_local]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html``
The procedure to install ``mpi4jax`` for the two situations is different.
Collaborator

We could always recommend --no-build-isolation for CUDA builds? I can't come up with a case where having it would hurt.

@dionhaefner
Collaborator

Fine with me. I could pick some nits with the logic in setup.py and the wording in the docs, but it feels like this is up for a refactor soon anyway, once we have to support jax[cuda12] (when jax[cuda12_pip] goes away).

@PhilipVinc
Member Author

Yes, that's what I was thinking. It seems that the issues with supporting jax[cuda12] are not on our side: the present interface, based on the PJRT runtime, does not support defining custom calls from Python.

They are fixing it in jax these days. I guess I'll try again in a week or two.

@PhilipVinc PhilipVinc merged commit 7829951 into master Jun 7, 2024
@PhilipVinc PhilipVinc deleted the pv/test-install branch June 7, 2024 14:37