|
16 | 16 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python |
17 | 17 |
|
18 | 18 | # This workflow will run a small SDXL training workload on a GPU runner. |
19 | | -# This workflow will run a small SDXL training workload on a GPU runner. |
20 | 19 |
|
21 | | -name: SDXL Workload Training on GPU |
| 20 | +# This workflow will run a small MaxText training workload on a GPU runner |
| 21 | +# by checking out the MaxText repo inside the MaxDiffusion environment. |
| 22 | + |
| 23 | +name: MaxText Workload on MaxDiffusion Runner |
22 | 24 |
|
23 | 25 | on: |
24 | 26 | pull_request: |
|
27 | 29 | workflow_dispatch: |
28 | 30 |
|
29 | 31 | jobs: |
30 | | - sdxl_training_workload: |
31 | | - name: "Run SDXL Training Workload" |
32 | | - # IMPORTANT: Replace with the label for your specific GPU runner if different |
| 32 | + maxtext_training_workload: |
| 33 | + name: "Run MaxText Training Workload" |
33 | 34 | runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
34 | 35 | container: |
35 | | - image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 |
| 36 | + # Using the MaxDiffusion container as requested |
| 37 | + image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
36 | 38 |
|
37 | 39 | steps: |
38 | | - - name: Verify Environment |
39 | | - run: | |
40 | | - echo "--- Verifying free space ---" |
41 | | - free -h |
42 | | - echo "--- Verifying shared memory size ---" |
43 | | - df -h /dev/shm |
44 | | -
|
45 | | - - name: Checkout Repository |
| 40 | + - name: Checkout MaxText Repository |
46 | 41 | uses: actions/checkout@v4 |
| 42 | + with: |
| 43 | + repository: 'AI-Hypercomputer/maxtext' |
| 44 | + ref: 'main' |
| 45 | + path: 'maxtext' # Clone it into a 'maxtext' subdirectory |
47 | 46 |
|
48 | 47 | - name: Install Dependencies |
| 48 | + working-directory: ./maxtext # Run this step's commands inside the checked-out MaxText repo
49 | 49 | run: | |
50 | | - pip install -r requirements.txt |
| 50 | + # Uninstall full tensorflow to prevent GPU conflicts with JAX |
51 | 51 | pip uninstall -y tensorflow |
| 52 | + # Install the CPU-only version for data loading |
52 | 53 | pip install tensorflow-cpu |
53 | | - pip install --upgrade torch torchvision |
| 54 | + # Install MaxText's dependencies |
| 55 | + pip install -r requirements.txt |
| 56 | + # Install the MaxText package itself |
54 | 57 | pip install . |
55 | | - |
| 58 | +
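If the tensorflow-cpu swap needs verifying, a minimal sanity check is sketched here (an assumption, not part of the diff; it presumes python3 is on the container's PATH):

    # Confirm JAX still sees the GPUs after the TensorFlow swap
    python3 -c "import jax; print(jax.devices())"
    # Confirm TensorFlow is the CPU-only build (this list should be empty)
    python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"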
|
56 | 59 | - name: List Installed Libraries |
| 60 | + working-directory: ./maxtext |
57 | 61 | run: | |
58 | 62 | echo "--- Installed Python packages ---" |
59 | 63 | pip freeze |
60 | | - |
61 | | - - name: Hugging Face Login |
62 | | - run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} |
63 | 64 |
|
64 | | - - name: Run SDXL Training |
| 65 | + - name: Run MaxText Training |
| 66 | + working-directory: ./maxtext |
65 | 67 | env: |
| 68 | + # Set the correct framework for Transformer Engine |
66 | 69 | NVTE_FRAMEWORK: jax |
| 70 | + # Prevent TensorFlow from grabbing all GPU memory |
67 | 71 | TF_FORCE_GPU_ALLOW_GROWTH: "true" |
68 | 72 | run: | |
69 | | - python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
70 | | - run_name="sdxl-ci-test-${{ github.run_id }}" \ |
71 | | - output_dir="/tmp/sdxl-output/" \ |
72 | | - max_train_steps=5 \ |
73 | | - hardware=gpu \ |
74 | | - attention="cudnn_flash_te" \ |
75 | | - resolution=512 \ |
76 | | - per_device_batch_size=1 \ |
77 | | - train_new_unet=true \ |
78 | | - train_text_encoder=false \ |
79 | | - cache_latents_text_encoder_outputs=true |
| 73 | + # Run the main training script with a base configuration |
| 74 | + python MaxText/train.py MaxText/configs/base.yml \ |
| 75 | + run_name="maxtext-ci-test-${{ github.run_id }}" \ |
| 76 | + steps=5 \ |
| 77 | + enable_checkpointing=false \ |
| 78 | + attention='cudnn_flash_te' \ |
| 79 | + dataset_type='synthetic' |
| 80 | +
|
| 81 | +
|
| 82 | +# name: SDXL Workload Training on GPU |
| 83 | + |
| 84 | +# on: |
| 85 | +# pull_request: |
| 86 | +# push: |
| 87 | +# branches: [ "main" ] |
| 88 | +# workflow_dispatch: |
| 89 | + |
| 90 | +# jobs: |
| 91 | +# sdxl_training_workload: |
| 92 | +# name: "Run SDXL Training Workload" |
| 93 | +# # IMPORTANT: Replace with the label for your specific GPU runner if different |
| 94 | +# runs-on: ["linux-x86-a3-megagpu-h100-8gpu"] |
| 95 | +# container: |
| 96 | +# image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 |
| 97 | + |
| 98 | +# steps: |
| 99 | +# - name: Verify Environment |
| 100 | +# run: | |
| 101 | +# echo "--- Verifying free space ---" |
| 102 | +# free -h |
| 103 | +# echo "--- Verifying shared memory size ---" |
| 104 | +# df -h /dev/shm |
| 105 | + |
| 106 | +# - name: Checkout Repository |
| 107 | +# uses: actions/checkout@v4 |
| 108 | + |
| 109 | +# - name: Install Dependencies |
| 110 | +# run: | |
| 111 | +# pip install -r requirements.txt |
| 112 | +# pip uninstall -y tensorflow |
| 113 | +# pip install tensorflow-cpu |
| 114 | +# pip install --upgrade torch torchvision |
| 115 | +# pip install . |
| 116 | + |
| 117 | +# - name: List Installed Libraries |
| 118 | +# run: | |
| 119 | +# echo "--- Installed Python packages ---" |
| 120 | +# pip freeze |
| 121 | + |
| 122 | +# - name: Hugging Face Login |
| 123 | +# run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }} |
| 124 | + |
| 125 | +# - name: Run SDXL Training |
| 126 | +# env: |
| 127 | +# NVTE_FRAMEWORK: jax |
| 128 | +# TF_FORCE_GPU_ALLOW_GROWTH: "true" |
| 129 | +# run: | |
| 130 | +# python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \ |
| 131 | +# run_name="sdxl-ci-test-${{ github.run_id }}" \ |
| 132 | +# output_dir="/tmp/sdxl-output/" \ |
| 133 | +# max_train_steps=5 \ |
| 134 | +# hardware=gpu \ |
| 135 | +# attention="cudnn_flash_te" \ |
| 136 | +# resolution=512 \ |
| 137 | +# per_device_batch_size=1 \ |
| 138 | +# train_new_unet=true \ |
| 139 | +# train_text_encoder=false \ |
| 140 | +# cache_latents_text_encoder_outputs=true |
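To reproduce this smoke test outside of CI, a rough local equivalent is sketched below; the MaxText GitHub URL, the --gpus all flag, and the bash wrapper are assumptions, while the image, install steps, and training flags mirror the workflow above.

    # Run the same JAX stable-stack GPU image used by the workflow (repo URL assumed)
    docker run --gpus all --rm \
      us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1 \
      bash -lc '
        git clone --depth 1 https://github.com/AI-Hypercomputer/maxtext.git && cd maxtext
        # Swap full TensorFlow for the CPU-only build, as in the workflow
        pip uninstall -y tensorflow && pip install tensorflow-cpu
        pip install -r requirements.txt && pip install .
        # Same env vars and training flags as the Run MaxText Training step
        NVTE_FRAMEWORK=jax TF_FORCE_GPU_ALLOW_GROWTH=true \
        python MaxText/train.py MaxText/configs/base.yml \
          run_name=local-smoke-test steps=5 enable_checkpointing=false \
          attention=cudnn_flash_te dataset_type=synthetic
      '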