[Don't merge] Verifying JAX AI image #3321
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow installs Python dependencies and runs a small SDXL training workload on a GPU runner.
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: SDXL Workload Training on GPU
on:
  pull_request:
  push:
    branches: [ "main" ]
  workflow_dispatch:
jobs:
  sdxl_training_workload:
    name: "Run SDXL Training Workload"
    # IMPORTANT: Replace with the label for your specific GPU runner if different
    runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
    steps:
      - name: Create and Activate Swap File
        run: |
          echo "--- Memory and swap before changes ---"
          free -h
          echo "---"
          echo "Creating and activating a 64GB swap file..."
          # Deactivate any existing swap to be safe
          sudo swapoff -a
          # Allocate a 64GB file
          sudo fallocate -l 64G /swapfile
          # Restrict access to root; swapon warns about insecure permissions otherwise
          sudo chmod 600 /swapfile
          # Format the file as swap
          sudo mkswap /swapfile
          # Activate the swap file
          sudo swapon /swapfile
          echo "--- Swap file is now active ---"
          sudo swapon --show
          echo "--- Memory and swap after changes ---"
          free -h
      - name: Checkout Repository
        uses: actions/checkout@v4
      - name: Install Dependencies
        run: |
          # Replace TensorFlow with the CPU-only build so it cannot claim GPU memory
          pip uninstall -y tensorflow
          pip install tensorflow-cpu
          pip install -r requirements.txt
          pip install --upgrade torch torchvision
          pip install .
      - name: List Installed Libraries
        run: |
          echo "--- Installed Python packages ---"
          pip freeze
      - name: Hugging Face Login
        run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
      - name: Run SDXL Training
        env:
          NVTE_FRAMEWORK: jax
          TF_FORCE_GPU_ALLOW_GROWTH: "true"
        run: |
          python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
            run_name="sdxl-ci-test-${{ github.run_id }}" \
            output_dir="/tmp/sdxl-output/" \
            max_train_steps=5 \
            hardware=gpu \
            attention="cudnn_flash_te" \
            resolution=512 \
            per_device_batch_size=1 \
            train_new_unet=true \
            train_text_encoder=false \
            cache_latents_text_encoder_outputs=true
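The swap step reports its result with `swapon --show`, which only helps someone reading the log. Below is a minimal sketch of a check that a job could act on instead. It assumes a Linux host where `/proc/swaps` lists each active swap area with its size in KiB. `parse_proc_swaps` and `swap_is_active` are hypothetical helpers, not part of this workflow or of maxdiffusion. The 63.9 GiB threshold is an assumed margin, because the reported size can sit slightly below the 64G that `fallocate` allocates.

```python
from __future__ import annotations

from pathlib import Path

# /proc/swaps reports sizes in KiB.
KIB_PER_GIB = 1024 * 1024


def parse_proc_swaps(text: str) -> dict[str, int]:
    """Map each active swap device or file to its size in KiB (hypothetical helper)."""
    sizes: dict[str, int] = {}
    # The first line is the "Filename Type Size Used Priority" header; skip it.
    for line in text.splitlines()[1:]:
        fields = line.split()
        if len(fields) >= 3:
            sizes[fields[0]] = int(fields[2])
    return sizes


def swap_is_active(path: str, min_gib: float, text: str | None = None) -> bool:
    """Return True if `path` is an active swap area of at least `min_gib` GiB.

    `text` lets callers pass /proc/swaps contents directly (e.g. in tests);
    by default the live file is read.
    """
    if text is None:
        text = Path("/proc/swaps").read_text()
    size_kib = parse_proc_swaps(text).get(path, 0)
    return size_kib >= min_gib * KIB_PER_GIB
```

A follow-up step could call `swap_is_active("/swapfile", 63.9)` and exit non-zero on `False`, so a failed swap setup stops the job before the training step runs.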
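To rerun the CI training outside GitHub Actions, the same invocation can be assembled in Python. This is a sketch under stated assumptions. `build_train_cmd` is a hypothetical helper, not a maxdiffusion API. It assumes the working directory is the repository root, as it is after `actions/checkout`. It only mirrors the arguments of the "Run SDXL Training" step, where each setting follows the config path as a separate `key=value` argument.

```python
from __future__ import annotations

import shlex
import sys


def build_train_cmd(run_id: str, overrides: dict[str, str] | None = None) -> list[str]:
    """Assemble the argv used by the "Run SDXL Training" step (hypothetical helper)."""
    # Defaults copied from the workflow's train_sdxl invocation.
    settings = {
        "run_name": f"sdxl-ci-test-{run_id}",
        "output_dir": "/tmp/sdxl-output/",
        "max_train_steps": "5",
        "hardware": "gpu",
        "attention": "cudnn_flash_te",
        "resolution": "512",
        "per_device_batch_size": "1",
        "train_new_unet": "true",
        "train_text_encoder": "false",
        "cache_latents_text_encoder_outputs": "true",
    }
    # Caller-supplied values replace the CI defaults key by key.
    settings.update(overrides or {})
    cmd = [
        sys.executable,
        "-m",
        "src.maxdiffusion.train_sdxl",
        "src/maxdiffusion/configs/base_xl.yml",
    ]
    # Each setting is passed as its own key=value argument after the config path.
    cmd += [f"{key}={value}" for key, value in settings.items()]
    return cmd


if __name__ == "__main__":
    # Print a shell-ready command for a one-step local smoke test.
    print(shlex.join(build_train_cmd("local", {"max_train_steps": "1"})))
```

The step also exports `NVTE_FRAMEWORK=jax` and `TF_FORCE_GPU_ALLOW_GROWTH=true`, and a local run would need them too, for example `subprocess.run(build_train_cmd("local"), env={**os.environ, "NVTE_FRAMEWORK": "jax", "TF_FORCE_GPU_ALLOW_GROWTH": "true"}, check=True)`.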