deps: Update jax to >=0.6.0 #331
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR contains the following updates:
>=0.5.3
->>=0.6.0
>=0.5.3
->>=0.6.0
Release Notes
jax-ml/jax (jax)
v0.6.0
Compare Source
Breaking changes
jax.numpy.array
no longer acceptsNone
. This behavior wasdeprecated since November 2023 and is now removed.
config.jax_data_dependent_tracing_fallback
config option,which was added temporarily in v0.4.36 to allow users to opt out of the
new "stackless" tracing machinery.
config.jax_eager_pmap
config option.lower
andtrace
AOT APIs on the resultof
jax.jit
if there have been subsequent wrappers applied.Previously this worked, but silently ignored the wrappers.
The workaround is to apply
jax.jit
last among the wrappers,and similarly for
jax.pmap
.See {jax-issue}
#27873
.cuda12_pip
extra forjax
has been removed; usepip install jax[cuda12]
instead.
Changes
supported.
align with PEP 685. For instance, if you were previously using
pip install jax[cuda12_local]
to install JAX, run
pip install jax[cuda12-local]
instead.jax.jit
now requiresfun
to be passed by position, and additionalarguments to be passed by keyword. Doing otherwise will result in a
DeprecationWarning in v0.6.X, and an error in starting in v0.7.X.
Deprecations
jax.tree_util.build_tree
is deprecated. Use {func}jax.tree.unflatten
instead.
and removed existing CPU/GPU handlers using XLA's custom call.
jax.lib.xla_extension
are now deprecated.jax.interpreters.mlir.hlo
andjax.interpreters.mlir.func_dialect
,which were accidental exports, have been removed. If needed, they are
available from
jax.extend.mlir
.jax.interpreters.mlir.custom_call
is deprecated. The APIs provided by{mod}
jax.ffi
should be used instead.jax.ffi.ffi_call
with inline arguments is nolonger supported. {func}
~jax.ffi.ffi_call
now unconditionally returns acallable.
jax.lib.xla_client
are deprecated:get_topology_for_devices
,heap_profile
,mlir_api_version
,Client
,CompileOptions
,DeviceAssignment
,Frame
,HloSharding
,OpSharding
,Traceback
.jax.util
are deprecated:HashableFunction
,as_hashable_function
,cache
,safe_map
,safe_zip
,split_dict
,split_list
,split_list_checked
,split_merge
,subvals
,toposort
,unzip2
,wrap_name
, andwraps
.jax.dlpack.to_dlpack
has been deprecated. You can usually pass a JAXArray
directly to thefrom_dlpack
function of another framework. If youneed the functionality of
to_dlpack
, use the__dlpack__
attribute of anarray.
jax.lax.infeed
,jax.lax.infeed_p
,jax.lax.outfeed
, andjax.lax.outfeed_p
are deprecated and will be removed in JAX v0.7.0.jax.lib.xla_client
:ArrayImpl
,FftType
,PaddingType
,PrimitiveType
,XlaBuilder
,dtype_to_etype
,ops
,register_custom_call_target
,shape_from_pyval
,Shape
,XlaComputation
.jax.lib.xla_extension
:ArrayImpl
,XlaRuntimeError
.jax
:jax.treedef_is_leaf
,jax.tree_flatten
,jax.tree_map
,jax.tree_leaves
,jax.tree_structure
,jax.tree_transpose
, andjax.tree_unflatten
. Replacements can be found in {mod}jax.tree
or{mod}
jax.tree_util
.jax.core
:AxisSize
,ClosedJaxpr
,EvalTrace
,InDBIdx
,InputType
,Jaxpr
,JaxprEqn
,Literal
,MapPrimitive
,OpaqueTraceState
,OutDBIdx
,Primitive
,Token
,TRACER_LEAK_DEBUGGER_WARNING
,Var
,concrete_aval
,dedup_referents
,escaped_tracer_error
,extend_axis_env_nd
,full_lower
,get_referent
,jaxpr_as_fun
,join_effects
,lattice_join
,leaked_tracer_error
,maybe_find_leaked_tracers
,raise_to_shaped
,raise_to_shaped_mappings
,reset_trace_state
,str_eqn_compact
,substitute_vars_in_output_ty
,typecompat
, andused_axis_names_jaxpr
. Mosthave no public replacement, though a few are available at {mod}
jax.extend.core
.vectorized
argument to {func}~jax.pure_callback
and{func}
~jax.ffi.ffi_call
. Use thevmap_method
parameter instead.Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Enabled.
♻ Rebasing: Whenever PR is behind base branch, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about these updates again.
This PR was generated by Mend Renovate. View the repository job log.