diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index b2523d6c3..331caeb3a 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -44,7 +44,7 @@ jobs: run: docker system prune --all --force - name: build maxdiffusion jax stable stack gpu image run: | - bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu MODE=stable_stack PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack_gpu BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:latest DEVICE=gpu + bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_gpu MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_gpu DEVICE=gpu - name: build maxdiffusion jax nightly image run: | bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly_gpu MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly DEVICE=gpu diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index ac20ed6d6..00ee172cf 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -92,6 +92,7 @@ diffusion_scheduler_config: { # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False base_output_directory: "" diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index a1450abee..f5a05b0e4 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -91,6 +91,7 @@ diffusion_scheduler_config: { # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index ede91b107..1113d03b6 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -104,6 +104,7 @@ diffusion_scheduler_config: { # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 49146e10d..3a5f294a9 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -120,6 +120,7 @@ base_output_directory: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index 6928b31d2..c37923911 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -120,6 +120,7 @@ base_output_directory: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 526593309..8e8db4a44 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -128,6 +128,7 @@ base_output_directory: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] diff --git a/src/maxdiffusion/configs/base_wan_t2v.yml b/src/maxdiffusion/configs/base_wan_t2v.yml index 28ef6e77e..002b9e73b 100644 --- a/src/maxdiffusion/configs/base_wan_t2v.yml +++ b/src/maxdiffusion/configs/base_wan_t2v.yml @@ -114,6 +114,7 @@ base_output_directory: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' +skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index 542b7957c..e773c19e0 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -95,7 +95,7 @@ base_output_directory: "" # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' - +skip_jax_distributed_system: False # Parallelism mesh_axes: ['data', 'fsdp', 'tensor'] diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index 307e826a8..aafeea2bd 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -71,7 +71,7 @@ diffusion_scheduler_config: { # Hardware hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' - +skip_jax_distributed_system: False # Output directory # Create a GCS bucket, e.g. my-maxtext-outputs and set this to "gs://my-maxtext-outputs/" base_output_directory: "" diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index a28bcff13..fab895f97 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -601,6 +601,9 @@ def initialize_jax_for_gpu(): def maybe_initialize_jax_distributed_system(raw_keys): + if raw_keys["skip_jax_distributed_system"]: + max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") + return if is_gpu_backend(raw_keys): max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") initialize_jax_for_gpu()