@@ -102,23 +102,23 @@ jobs:
102102 run : |
103103 python -c "import jax; print(jax.devices())"
104104
105- # - name: Run MaxDiffusion Training
106- # run: |
107- # # This command is adapted from your DAG for a single-slice configuration.
108- # NVTE_FUSED_ATTN=1 pip install . && \
109- # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110- # hardware=gpu \
111- # train_new_unet=true \
112- # train_text_encoder=false \
113- # cache_latents_text_encoder_outputs=true \
114- # per_device_batch_size=1 \
115- # attention=dot_product \
116- # activations_dtype=bfloat16 \
117- # weights_dtype=bfloat16 \
118- # max_train_steps=200 \
119- # enable_profiler=True \
120- # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121- # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
105+ - name: Run MaxDiffusion Training
106+ run: |
107+ # This command is adapted from your DAG for a single-slice configuration.
108+ NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
109+ python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110+ hardware=gpu \
111+ train_new_unet=true \
112+ train_text_encoder=false \
113+ cache_latents_text_encoder_outputs=true \
114+ per_device_batch_size=1 \
115+ attention=dot_product \
116+ activations_dtype=bfloat16 \
117+ weights_dtype=bfloat16 \
118+ max_train_steps=200 \
119+ enable_profiler=True \
120+ run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121+ output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
122122
123123# jobs:
124124# build: