JAX is changing the default jax.pmap implementation #32412
danielsuo announced in Announcements
tl;dr
As of JAX 0.8.0, the default implementation of `jax.pmap` will be based on `jax.jit` and `jax.shard_map`. The new implementation is not a perfect replacement for the original. We've published documentation for this change to help users who run into trouble.

This change makes `jax.pmap` integrate well with JAX shardings and simplifies the implementation.

Help! Fix me now!
IMPORTANT: The options below are not a permanent fix. Until January 15, 2026, it will be possible to temporarily use the old version of `jax.pmap` by doing one of the following:

- Setting the shell environment variable `JAX_PMAP_SHMAP_MERGE` to something false-like (e.g., `0`);
- Setting the boolean flag `--jax_pmap_shmap_merge` to something false-like if your code parses flags with `absl-py`;
- Setting the same option to a false value with a configuration statement in your main file, or anywhere before you call `jax.pmap`.

NOTE: If you need one of these workarounds, please file a bug and tag @danielsuo with a reproducer so we can resolve the issue as quickly as possible under the new `jax.pmap`.

How do I know I'm broken? What are some examples and fixes?
Please see the documentation for this change here. We include a number of typical issues that can come up and how to resolve them.
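
For the in-code opt-out listed under "Help! Fix me now!", here is a minimal sketch. It assumes the option can be set through `jax.config.update` under the name `jax_pmap_shmap_merge`, which is inferred from the flag name above and is not necessarily the exact statement the announcement intended. The doubling function is purely illustrative.

```python
import jax
import jax.numpy as jnp

# Assumption: "jax_pmap_shmap_merge" is the config name matching the
# --jax_pmap_shmap_merge flag; False requests the old pmap implementation.
# Run this before the first call to jax.pmap.
jax.config.update("jax_pmap_shmap_merge", False)

# pmap maps over the leading axis, so give it one element per local device.
n_devices = jax.local_device_count()
xs = jnp.arange(n_devices, dtype=jnp.float32)

# Illustrative workload: each device doubles its own element.
doubled = jax.pmap(lambda x: x * 2.0)(xs)
print(doubled)
```

Because this is a temporary escape hatch, it may be worth keeping the call in a single, clearly marked place so it is easy to delete once the underlying issue is fixed.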