Skip to content

Commit 03d4bae

Browse files
authored
Update UnitTests.yml
1 parent 70beed7 commit 03d4bae

1 file changed

Lines changed: 14 additions & 9 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
1616
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
1717

18-
# This workflow will run a small FLUX training workload on a GPU runner.
18+
# This workflow will run a small SDXL training workload on a GPU runner.
1919

20-
name: FLUX Workload Training on GPU
20+
name: SDXL Workload Training on GPU
2121

2222
on:
2323
pull_request:
@@ -26,8 +26,8 @@ on:
2626
workflow_dispatch:
2727

2828
jobs:
29-
flux_training_workload:
30-
name: "Run FLUX Training Workload"
29+
sdxl_training_workload:
30+
name: "Run SDXL Training Workload"
3131
# IMPORTANT: Replace with the label for your specific GPU runner if different
3232
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
3333
container:
@@ -48,17 +48,22 @@ jobs:
4848
run: |
4949
echo "--- Installed Python packages ---"
5050
pip freeze
51-
51+
5252
- name: Hugging Face Login
5353
run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
5454

55-
- name: Run FLUX Training
55+
- name: Run SDXL Training
5656
env:
5757
NVTE_FRAMEWORK: jax
5858
run: |
59-
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
60-
run_name="flux-ci-test-${{ github.run_id }}" \
61-
output_dir="/tmp/flux-output/" \
59+
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
60+
run_name="sdxl-ci-test-${{ github.run_id }}" \
61+
output_dir="/tmp/sdxl-output/" \
6262
max_train_steps=5 \
6363
hardware=gpu \
6464
attention="cudnn_flash_te" \
65+
resolution=512 \
66+
per_device_batch_size=1 \
67+
train_new_unet=true \
68+
train_text_encoder=false \
69+
cache_latents_text_encoder_outputs=true

0 commit comments

Comments
 (0)