Extending mpi4jax with XPU support #226
Added initial SYCL infrastructure
Jczaja/sycl except
- Fix to sendrecv problem
Remove redundant export
@dionhaefner, @PhilipVinc Anyway, we are looking forward to your feedback.
Thanks! This looks high-quality, let's fix some minor issues.
In case this breaks in the future: Who is going to maintain this code?
@dionhaefner Thanks for reviewing this PR. This functionality will be maintained by the Intel JAX GPU teams. I will send some more info via email. We have started working on the suggestions you made and will update the PR as soon as we have something.
Jczaja/xpu support setup
@dionhaefner, @PhilipVinc We addressed (hopefully) all the questions and comments and updated the PR with the relevant changes, so we are ready for a second round of review. Please review!
Thanks, I like the reduced code duplication. Let's get some final changes done, then merge this.
@jczaja Could you please address the last 3 comments so we can get this merged?
@dionhaefner I apologize for the lack of activity on our side. The reasons were explained in an email I sent some time ago. Let me introduce my colleague @Zantares, who will take over maintaining this work.
Hi @dionhaefner, we have taken over the related work and will give feedback within 1 week. Thanks for your patience!
@dionhaefner We have addressed the remaining 3 comments. Could you help review? Thank you!
LGTM, thanks all!
@PhilipVinc You want to give it a final pass?
Hi,
This PR is a draft of changes enabling Intel XPU (via SYCL). The changes are modeled after the GPU (CUDA) integration, so hopefully they should be easy to understand.
Unit tests pass rate:
Currently, on our setup with XPU, we got:
which matches the result when we run purely on CPU (test_common.py fails for some reason on both XPU and CPU execution).
Usage:
A new environment variable, MPI4JAX_USE_SYCL_MPI, was added by analogy to MPI4JAX_USE_CUDA_MPI. MPI4JAX_USE_SYCL_MPI=1 makes mpi4jax assume that the MPI implementation in use can work with XPU buffers directly (GPU/XPU-aware). MPI4JAX_USE_SYCL_MPI=0 makes mpi4jax assume that the MPI implementation cannot work with XPU buffers, so the XPU buffer is copied into an allocated CPU buffer first (analogous to what is done for CUDA).
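To make the two modes concrete, here is a minimal Python sketch of how such a flag could be interpreted. It is not code from this PR: the helper names `sycl_aware_mpi_enabled` and `describe_transfer_path` are hypothetical, and treating an unset variable the same as `0` is an assumption made only for this illustration.

```python
import os


def sycl_aware_mpi_enabled(environ=None):
    """Return True if MPI4JAX_USE_SYCL_MPI is set to "1".

    Hypothetical helper for illustration only; not part of mpi4jax.
    Assumption: an unset variable (or any value other than "1") is
    treated as disabled.
    """
    environ = os.environ if environ is None else environ
    return environ.get("MPI4JAX_USE_SYCL_MPI", "0").strip() == "1"


def describe_transfer_path(environ=None):
    """Describe which of the two paths from the PR description applies."""
    if sycl_aware_mpi_enabled(environ):
        # XPU-aware MPI: the device buffer is handed to MPI directly.
        return "direct: pass the XPU buffer to MPI"
    # Non-aware MPI: stage the data through an allocated CPU buffer.
    return "staged: copy the XPU buffer to a CPU buffer, then call MPI"


if __name__ == "__main__":
    print(describe_transfer_path())
```

In practice the variable would presumably be set in the environment of the launched processes before mpi4jax is imported, for example `MPI4JAX_USE_SYCL_MPI=1 mpirun -n 2 python my_script.py`, where `my_script.py` is a placeholder name.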
TODO:
Concerns:
Any suggestions are welcome!
@wozna, @sfraczek, @bartekkuncer, @shssf