1515# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
1616# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
1717
18- # This workflow will run a small FLUX training workload on a GPU runner.
18+ # This workflow will run a small SDXL training workload on a GPU runner.
1919
20- name : FLUX Workload Training on GPU
20+ name : SDXL Workload Training on GPU
2121
2222on :
2323 pull_request :
2626 workflow_dispatch :
2727
2828jobs :
29- flux_training_workload :
30- name : " Run FLUX Training Workload"
29+ sdxl_training_workload :
30+ name : " Run SDXL Training Workload"
3131 # IMPORTANT: Replace with the label for your specific GPU runner if different
3232 runs-on : ["linux-x86-a3-megagpu-h100-8gpu"]
3333 container :
@@ -48,17 +48,22 @@ jobs:
4848 run : |
4949 echo "--- Installed Python packages ---"
5050 pip freeze
51-
51+
5252 - name : Hugging Face Login
5353 run : huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
5454
55- - name : Run FLUX Training
55+ - name : Run SDXL Training
5656 env :
5757 NVTE_FRAMEWORK : jax
5858 run : |
59- python src/ maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev .yml \
60- run_name="flux -ci-test-${{ github.run_id }}" \
61- output_dir="/tmp/flux -output/" \
59+ python -m src. maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl .yml \
60+ run_name="sdxl -ci-test-${{ github.run_id }}" \
61+ output_dir="/tmp/sdxl -output/" \
6262 max_train_steps=5 \
6363 hardware=gpu \
6464 attention="cudnn_flash_te" \
65+ resolution=512 \
66+ per_device_batch_size=1 \
67+ train_new_unet=true \
68+ train_text_encoder=false \
69+ cache_latents_text_encoder_outputs=true
0 commit comments