File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -38,24 +38,20 @@ jobs:
3838 - name: Run MaxDiffusion Training
3939 run: |
4040 # This command is adapted from your DAG for a single-slice configuration.
41- JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \
42- TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \
43- JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \
44- pip install . && \
45- python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
46- pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
47- revision=refs/pr/95 \
41+ NVTE_FUSED_ATTN=1 pip install . && \
42+ python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
43+ hardware=gpu \
44+ train_new_unet=true \
45+ train_text_encoder=false \
46+ cache_latents_text_encoder_outputs=true \
47+ per_device_batch_size=1 \
48+ attention=cudnn_flash_te \
4849 activations_dtype=bfloat16 \
4950 weights_dtype=bfloat16 \
50- dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \
51- resolution=1024 \
52- per_device_batch_size=1 \
53- jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \
54- max_train_steps=20 \
55- attention=flash \
51+ max_train_steps=200 \
5652 enable_profiler=True \
57- run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \
58- output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }}
53+ run_name=1slice-VGpuVersion.XPK_H100_a3-maxdiffusion-jax-stable-stack-2025-09-26-04-12-02 \
54+ output_dir=gs://ml-auto-solutions/output/maxdiffusion/automated/maxdiffusion_sdxl/${{ github.run_id }}
5955
6056# jobs:
6157# build:
You can’t perform that action at this time.
0 commit comments