Skip to content

Commit 333e4f5

Browse files
try with workload
1 parent 95c1b03 commit 333e4f5

1 file changed

Lines changed: 17 additions & 17 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,23 +102,23 @@ jobs:
102102
run: |
103103
python -c "import jax; print(jax.devices())"
104104
105-
# - name: Run MaxDiffusion Training
106-
# run: |
107-
# # This command is adapted from your DAG for a single-slice configuration.
108-
# NVTE_FUSED_ATTN=1 pip install . && \
109-
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110-
# hardware=gpu \
111-
# train_new_unet=true \
112-
# train_text_encoder=false \
113-
# cache_latents_text_encoder_outputs=true \
114-
# per_device_batch_size=1 \
115-
# attention=dot_product \
116-
# activations_dtype=bfloat16 \
117-
# weights_dtype=bfloat16 \
118-
# max_train_steps=200 \
119-
# enable_profiler=True \
120-
# run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121-
# output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
105+
- name: Run MaxDiffusion Training
106+
run: |
107+
# This command is adapted from your DAG for a single-slice configuration.
108+
NVTE_FRAMEWORK=JAX NVTE_FUSED_ATTN=1 pip install . && \
109+
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
110+
hardware=gpu \
111+
train_new_unet=true \
112+
train_text_encoder=false \
113+
cache_latents_text_encoder_outputs=true \
114+
per_device_batch_size=1 \
115+
attention=dot_product \
116+
activations_dtype=bfloat16 \
117+
weights_dtype=bfloat16 \
118+
max_train_steps=200 \
119+
enable_profiler=True \
120+
run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
121+
output_dir=gs://rbierneni-multipod-dev/${{ github.run_id }}
122122
123123
# jobs:
124124
# build:

0 commit comments

Comments
 (0)