Skip to content

Fix Unit test failure for JAX/Flax version update#264

Merged
susanbao merged 1 commit intomainfrom
sanbao/flax
Oct 9, 2025
Merged

Fix Unit test failure for JAX/Flax version update#264
susanbao merged 1 commit intomainfrom
sanbao/flax

Conversation

@susanbao
Copy link
Copy Markdown
Collaborator

@susanbao susanbao commented Oct 9, 2025

The newer version of JAX 0.7.2 and Flax 0.12.0 now strictly requires a mesh to be defined whenever you initialize parameters with sharding rules, even in a single-device unit test environment. Our unit tests failed for this.

For Flax team, the issue is due to this change: https://github.com/google/flax/blob/main/docs_nnx/flip/4844-var-eager-sharding.md

Simplify the creation of sharded NNX models. When a sharding annotation is provided, all nnx.Variable creation will require a mesh context and automatically be sharded as annotated.

It can be disabled by using flax.config.update('flax_always_shard_variable', False)

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Oct 9, 2025

@susanbao susanbao changed the title Flax Fix Unit test failure for JAX/Flax version update Oct 9, 2025
@susanbao susanbao merged commit 972b4ff into main Oct 9, 2025
3 of 4 checks passed
@susanbao susanbao deleted the sanbao/flax branch October 13, 2025 05:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants