Skip to content

Commit d9617a0

Browse files
committed
fix
1 parent a56787c commit d9617a0

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

src/maxdiffusion/tests/ltx_2_transformer_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import jax.numpy as jnp
55
from flax import nnx
66
from flax.linen import partitioning as nn_partitioning
7+
import flax
8+
# Matches WanTransformerTest: disable eager sharding to avoid "mesh context required" errors during init
9+
flax.config.update("flax_always_shard_variable", False)
710
from jax.sharding import Mesh
811
import os
912
from maxdiffusion import pyconfig

0 commit comments

Comments
 (0)