Skip to content

Commit 9172655

Browse files
authored
Update UnitTests.yml
1 parent 18b8e2b commit 9172655

1 file changed

Lines changed: 93 additions & 32 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
1717

1818
# This workflow will run a small SDXL training workload on a GPU runner.
19-
# This workflow will run a small SDXL training workload on a GPU runner.
2019

21-
name: SDXL Workload Training on GPU
20+
# This workflow will run a small MaxText training workload on a GPU runner
21+
# by checking out the MaxText repo inside the MaxDiffusion environment.
22+
23+
name: MaxText Workload on MaxDiffusion Runner
2224

2325
on:
2426
pull_request:
@@ -27,53 +29,112 @@ on:
2729
workflow_dispatch:
2830

2931
jobs:
30-
sdxl_training_workload:
31-
name: "Run SDXL Training Workload"
32-
# IMPORTANT: Replace with the label for your specific GPU runner if different
32+
maxtext_training_workload:
33+
name: "Run MaxText Training Workload"
3334
runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
3435
container:
35-
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
36+
# Using the MaxDiffusion container as requested
37+
image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1:latest
3638

3739
steps:
38-
- name: Verify Environment
39-
run: |
40-
echo "--- Verifying free space ---"
41-
free -h
42-
echo "--- Verifying shared memory size ---"
43-
df -h /dev/shm
44-
45-
- name: Checkout Repository
40+
- name: Checkout MaxText Repository
4641
uses: actions/checkout@v4
42+
with:
43+
repository: 'AI-Hypercomputer/maxtext'
44+
ref: 'main'
45+
path: 'maxtext' # Clone it into a 'maxtext' subdirectory
4746

4847
- name: Install Dependencies
48+
working-directory: ./maxtext # Run all subsequent commands inside the new directory
4949
run: |
50-
pip install -r requirements.txt
50+
# Uninstall full tensorflow to prevent GPU conflicts with JAX
5151
pip uninstall -y tensorflow
52+
# Install the CPU-only version for data loading
5253
pip install tensorflow-cpu
53-
pip install --upgrade torch torchvision
54+
# Install MaxText's dependencies
55+
pip install -r requirements.txt
56+
# Install the MaxText package itself
5457
pip install .
55-
58+
5659
- name: List Installed Libraries
60+
working-directory: ./maxtext
5761
run: |
5862
echo "--- Installed Python packages ---"
5963
pip freeze
60-
61-
- name: Hugging Face Login
62-
run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
6364
64-
- name: Run SDXL Training
65+
- name: Run MaxText Training
66+
working-directory: ./maxtext
6567
env:
68+
# Set the correct framework for Transformer Engine
6669
NVTE_FRAMEWORK: jax
70+
# Prevent TensorFlow from grabbing all GPU memory
6771
TF_FORCE_GPU_ALLOW_GROWTH: "true"
6872
run: |
69-
python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
70-
run_name="sdxl-ci-test-${{ github.run_id }}" \
71-
output_dir="/tmp/sdxl-output/" \
72-
max_train_steps=5 \
73-
hardware=gpu \
74-
attention="cudnn_flash_te" \
75-
resolution=512 \
76-
per_device_batch_size=1 \
77-
train_new_unet=true \
78-
train_text_encoder=false \
79-
cache_latents_text_encoder_outputs=true
73+
# Run the main training script with a base configuration
74+
python MaxText/train.py MaxText/configs/base.yml \
75+
run_name="maxtext-ci-test-${{ github.run_id }}" \
76+
steps=5 \
77+
enable_checkpointing=false \
78+
attention='cudnn_flash_te' \
79+
dataset_type='synthetic'
80+
81+
82+
# name: SDXL Workload Training on GPU
83+
84+
# on:
85+
# pull_request:
86+
# push:
87+
# branches: [ "main" ]
88+
# workflow_dispatch:
89+
90+
# jobs:
91+
# sdxl_training_workload:
92+
# name: "Run SDXL Training Workload"
93+
# # IMPORTANT: Replace with the label for your specific GPU runner if different
94+
# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
95+
# container:
96+
# image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
97+
98+
# steps:
99+
# - name: Verify Environment
100+
# run: |
101+
# echo "--- Verifying free space ---"
102+
# free -h
103+
# echo "--- Verifying shared memory size ---"
104+
# df -h /dev/shm
105+
106+
# - name: Checkout Repository
107+
# uses: actions/checkout@v4
108+
109+
# - name: Install Dependencies
110+
# run: |
111+
# pip install -r requirements.txt
112+
# pip uninstall -y tensorflow
113+
# pip install tensorflow-cpu
114+
# pip install --upgrade torch torchvision
115+
# pip install .
116+
117+
# - name: List Installed Libraries
118+
# run: |
119+
# echo "--- Installed Python packages ---"
120+
# pip freeze
121+
122+
# - name: Hugging Face Login
123+
# run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
124+
125+
# - name: Run SDXL Training
126+
# env:
127+
# NVTE_FRAMEWORK: jax
128+
# TF_FORCE_GPU_ALLOW_GROWTH: "true"
129+
# run: |
130+
# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
131+
# run_name="sdxl-ci-test-${{ github.run_id }}" \
132+
# output_dir="/tmp/sdxl-output/" \
133+
# max_train_steps=5 \
134+
# hardware=gpu \
135+
# attention="cudnn_flash_te" \
136+
# resolution=512 \
137+
# per_device_batch_size=1 \
138+
# train_new_unet=true \
139+
# train_text_encoder=false \
140+
# cache_latents_text_encoder_outputs=true

0 commit comments

Comments
 (0)