diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py
index 1427d019b..451b28292 100644
--- a/src/maxdiffusion/generate_wan.py
+++ b/src/maxdiffusion/generate_wan.py
@@ -21,6 +21,7 @@ from absl import app
 from absl import app
 from maxdiffusion.utils import export_to_video
 from google.cloud import storage
+import flax
 
 
 def upload_video_to_gcs(output_dir: str, video_path: str):
@@ -161,6 +162,7 @@ def run(config, pipeline=None, filename_prefix=""):
 
 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
+  flax.config.update('flax_always_shard_variable', False)
   run(pyconfig.config)
 
 
diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py
index 8d4774987..2fbb069d3 100644
--- a/src/maxdiffusion/train_wan.py
+++ b/src/maxdiffusion/train_wan.py
@@ -20,6 +20,7 @@ from absl import app
 from absl import app
 from maxdiffusion import max_logging, pyconfig
 from maxdiffusion.train_utils import validate_train_config
+import flax
 
 
 def train(config):
@@ -34,6 +35,7 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.config
   validate_train_config(config)
   max_logging.log(f"Found {jax.device_count()} devices.")
+  flax.config.update('flax_always_shard_variable', False)
   train(config)
 
 