We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent a56787c commit d9617a0Copy full SHA for d9617a0
1 file changed
src/maxdiffusion/tests/ltx_2_transformer_test.py
@@ -4,6 +4,9 @@
4
import jax.numpy as jnp
5
from flax import nnx
6
from flax.linen import partitioning as nn_partitioning
7
+import flax
8
+# Matches WanTransformerTest: disable eager sharding to avoid "mesh context required" errors during init
9
+flax.config.update("flax_always_shard_variable", False)
10
from jax.sharding import Mesh
11
import os
12
from maxdiffusion import pyconfig
0 commit comments