@@ -98,27 +98,28 @@ jobs:
9898 # pip install tensorflow-cpu
9999 pip freeze
100100
101- - name : Check per_device_batch_size
101+ - name : Check devices
102102 run : |
103103 python -c "import jax; print(jax.devices())"
104+ python verify_conflict.py
104105
105- - name : Run MaxDiffusion Training
106- run : |
107- # This command is adapted from your DAG for a single-slice configuration.
108- NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
109- python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110- hardware=gpu \
111- train_new_unet=true \
112- train_text_encoder=false \
113- cache_latents_text_encoder_outputs=true \
114- per_device_batch_size=1 \
115- attention=dot_product \
116- activations_dtype=bfloat16 \
117- weights_dtype=bfloat16 \
118- max_train_steps=200 \
119- enable_profiler=True \
120- run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121- output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
106+ # - name: Run MaxDiffusion Training
107+ # run: |
108+ # # This command is adapted from your DAG for a single-slice configuration.
109+ # NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
110+ # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
111+ # hardware=gpu \
112+ # train_new_unet=true \
113+ # train_text_encoder=false \
114+ # cache_latents_text_encoder_outputs=true \
115+ # per_device_batch_size=1 \
116+ # attention=dot_product \
117+ # activations_dtype=bfloat16 \
118+ # weights_dtype=bfloat16 \
119+ # max_train_steps=200 \
120+ # enable_profiler=True \
121+ # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
122+ # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
122123
123124# jobs:
124125# build:
0 commit comments