Skip to content

Commit ba8e3a7

Browse files
Bump jax from 0.6.2 to 0.7.0 in /mpi4jax/_src (#283)
* Bump jax from 0.6.2 to 0.7.0 in /mpi4jax/_src Bumps [jax](https://github.com/jax-ml/jax) from 0.6.2 to 0.7.0. - [Release notes](https://github.com/jax-ml/jax/releases) - [Changelog](https://github.com/jax-ml/jax/blob/main/CHANGELOG.md) - [Commits](jax-ml/jax@jax-v0.6.2...jax-v0.7.0) --- updated-dependencies: - dependency-name: jax dependency-version: 0.7.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * instantiate zeros * bump python version for coverage test --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Dion Häfner <mail@dionhaefner.de>
1 parent a06d78a commit ba8e3a7

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

.github/workflows/covecov-coverage.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ jobs:
1515
fail-fast: false
1616

1717
matrix:
18-
python-version: ["3.10"]
18+
python-version: ["3.13"]
1919
os: [ubuntu-latest]
20+
mpi: [openmpi]
2021

2122
env:
2223
MPICH_INTERFACE_HOSTNAME: localhost

mpi4jax/_src/_latest_jax_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# latest tested JAX version
22
# bumped automatically by dependabot
3-
jax==0.6.2
3+
jax==0.7.0

mpi4jax/_src/collective_ops/allreduce.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def mpi_allreduce_transpose_rule(x_tan, *x_args, op, comm, transpose):
223223
raise NotImplementedError(
224224
"The linear transpose of allreduce is only defined for op=MPI.SUM"
225225
)
226+
x_tan = ad.instantiate_zeros(x_tan)
226227
res = mpi_allreduce_p.bind(x_tan, op=op, comm=comm, transpose=(not transpose))
227228
return (res,)
228229

0 commit comments

Comments
 (0)