Skip to content

Commit 2deaa7e

Browse files
Update with command for gpu
1 parent 1f22cf5 commit 2deaa7e

1 file changed

Lines changed: 11 additions & 15 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,20 @@ jobs:
3838
- name: Run MaxDiffusion Training
3939
run: |
4040
# This command is adapted from your DAG for a single-slice configuration.
41-
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \
42-
TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \
43-
JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \
44-
pip install . && \
45-
python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
46-
pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
47-
revision=refs/pr/95 \
41+
NVTE_FUSED_ATTN=1 pip install . && \
42+
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
43+
hardware=gpu \
44+
train_new_unet=true \
45+
train_text_encoder=false \
46+
cache_latents_text_encoder_outputs=true \
47+
per_device_batch_size=1 \
48+
attention=cudnn_flash_te \
4849
activations_dtype=bfloat16 \
4950
weights_dtype=bfloat16 \
50-
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \
51-
resolution=1024 \
52-
per_device_batch_size=1 \
53-
jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \
54-
max_train_steps=20 \
55-
attention=flash \
51+
max_train_steps=200 \
5652
enable_profiler=True \
57-
run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \
58-
output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }}
53+
run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
54+
output_dir=gs://ml-auto-solutions/output/maxdiffusion/automated/maxdiffusion_sdxl/${{ github.run_id }}
5955
6056
# jobs:
6157
# build:

0 commit comments

Comments
 (0)