@@ -98,23 +98,27 @@ jobs:
9898 # pip install tensorflow-cpu
9999 pip freeze
100100
101- - name : Run MaxDiffusion Training
101+ - name : Check per_device_batch_size
102102 run : |
103- # This command is adapted from your DAG for a single-slice configuration.
104- NVTE_FUSED_ATTN=1 pip install . && \
105- python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
106- hardware=gpu \
107- train_new_unet=true \
108- train_text_encoder=false \
109- cache_latents_text_encoder_outputs=true \
110- per_device_batch_size=1 \
111- attention=dot_product \
112- activations_dtype=bfloat16 \
113- weights_dtype=bfloat16 \
114- max_train_steps=200 \
115- enable_profiler=True \
116- run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
117- output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
103+ python -c "import jax; print(jax.devices())"
104+
105+ # - name: Run MaxDiffusion Training
106+ # run: |
107+ # # This command is adapted from your DAG for a single-slice configuration.
108+ # NVTE_FUSED_ATTN=1 pip install . && \
109+ # python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110+ # hardware=gpu \
111+ # train_new_unet=true \
112+ # train_text_encoder=false \
113+ # cache_latents_text_encoder_outputs=true \
114+ # per_device_batch_size=1 \
115+ # attention=dot_product \
116+ # activations_dtype=bfloat16 \
117+ # weights_dtype=bfloat16 \
118+ # max_train_steps=200 \
119+ # enable_profiler=True \
120+ # run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121+ # output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
118122
119123# jobs:
120124# build:
0 commit comments