Skip to content

Commit 95c1b03

Browse files
Check jax.devices output
1 parent 9295b8b commit 95c1b03

1 file changed

Lines changed: 20 additions & 16 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -98,23 +98,27 @@ jobs:
9898
# pip install tensorflow-cpu
9999
pip freeze
100100
101-
- name: Run MaxDiffusion Training
101+
- name: Check per_device_batch_size
102102
run: |
103-
# This command is adapted from your DAG for a single-slice configuration.
104-
NVTE_FUSED_ATTN=1 pip install . && \
105-
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
106-
hardware=gpu \
107-
train_new_unet=true \
108-
train_text_encoder=false \
109-
cache_latents_text_encoder_outputs=true \
110-
per_device_batch_size=1 \
111-
attention=dot_product \
112-
activations_dtype=bfloat16 \
113-
weights_dtype=bfloat16 \
114-
max_train_steps=200 \
115-
enable_profiler=True \
116-
run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
117-
output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
103+
python -c "import jax; print(jax.devices())"
104+
105+
# - name: Run MaxDiffusion Training
106+
# run: |
107+
# # This command is adapted from your DAG for a single-slice configuration.
108+
# NVTE_FUSED_ATTN=1 pip install . && \
109+
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110+
# hardware=gpu \
111+
# train_new_unet=true \
112+
# train_text_encoder=false \
113+
# cache_latents_text_encoder_outputs=true \
114+
# per_device_batch_size=1 \
115+
# attention=dot_product \
116+
# activations_dtype=bfloat16 \
117+
# weights_dtype=bfloat16 \
118+
# max_train_steps=200 \
119+
# enable_profiler=True \
120+
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121+
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
118122

119123
# jobs:
120124
# build:

0 commit comments

Comments
 (0)