Skip to content

Fix WAN training/inference error for new JAX/Flax Version#265

Merged
susanbao merged 1 commit intomainfrom
sanbao/test
Oct 9, 2025
Merged

Fix WAN training/inference error for new JAX/Flax Version#265
susanbao merged 1 commit intomainfrom
sanbao/test

Conversation

@susanbao
Copy link
Copy Markdown
Collaborator

@susanbao susanbao commented Oct 9, 2025

The newer version of JAX 0.7.2 and Flax 0.12.0 now strictly requires a mesh to be defined whenever you initialize parameters with sharding rules, even in a single-device unit test environment.

From Flax team, the issue is due to this change: https://github.com/google/flax/blob/main/docs_nnx/flip/4844-var-eager-sharding.md

Simplify the creation of sharded NNX models. When a sharding annotation is provided, all nnx.Variable creation will require a mesh context and automatically be sharded as annotated.

It can be disabled by using flax.config.update('flax_always_shard_variable', False).

In maxDiffusion, only WAN 2.1 use the flax modules.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Oct 9, 2025

@susanbao susanbao merged commit 158e1f2 into main Oct 9, 2025
3 of 4 checks passed
@susanbao susanbao deleted the sanbao/test branch October 13, 2025 05:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants