Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from absl import app
from maxdiffusion.utils import export_to_video
from google.cloud import storage
import flax


def upload_video_to_gcs(output_dir: str, video_path: str):
Expand Down Expand Up @@ -161,6 +162,7 @@ def run(config, pipeline=None, filename_prefix=""):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
flax.config.update('flax_always_shard_variable', False)
run(pyconfig.config)


Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/train_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from absl import app
from maxdiffusion import max_logging, pyconfig
from maxdiffusion.train_utils import validate_train_config
import flax


def train(config):
Expand All @@ -34,6 +35,7 @@ def main(argv: Sequence[str]) -> None:
config = pyconfig.config
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
flax.config.update('flax_always_shard_variable', False)
train(config)


Expand Down
Loading