Skip to content

Commit 00d413e

Browse files
committed
jax distributed init
1 parent 2ad2b97 commit 00d413e

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

src/maxdiffusion/tests/wan_magcache_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323

2424
from maxdiffusion import pyconfig
2525
from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1
26+
import jax
2627

28+
try:
29+
jax.distributed.initialize()
30+
except Exception:
31+
pass
2732

2833
IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"
2934
THIS_DIR = os.path.dirname(os.path.abspath(__file__))

0 commit comments

Comments
 (0)