[ Dont merge ] verifying jax ai image #3321

Workflow file for this run

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow will run a small SDXL training workload on a GPU runner.
name: SDXL Workload Training on GPU
on:
  pull_request:
  push:
    branches: [ "main" ]
  workflow_dispatch:
jobs:
  sdxl_training_workload:
    name: "Run SDXL Training Workload"
    # IMPORTANT: Replace with the label for your specific GPU runner if different
    runs-on: ["linux-x86-a3-megagpu-h100-8gpu"]
    container:
      image: us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/candidate/gpu:jax0.7.2-cuda12.9-rev1
    steps:
      - name: Create and Activate Swap File
        run: |
          echo "--- Verifying free space before changes ---"
          free -h
          echo "---"
          echo "Creating and activating a 64GB swap file..."
          # Deactivate any existing swap to be safe
          sudo swapoff -a
          # Allocate a 64GB file
          sudo fallocate -l 64G /swapfile
          # Set the correct permissions
          sudo chmod 600 /swapfile
          # Format the file as swap
          sudo mkswap /swapfile
          # Activate the swap file
          sudo swapon /swapfile
          echo "--- Swap file is now active ---"
          sudo swapon --show
          echo "--- Verifying free space after changes ---"
          free -h
      - name: Checkout Repository
        uses: actions/checkout@v4
      - name: Install Dependencies
        run: |
          pip uninstall -y tensorflow
          pip install tensorflow-cpu
          pip install -r requirements.txt
          pip install --upgrade torch torchvision
          pip install .
      - name: List Installed Libraries
        run: |
          echo "--- Installed Python packages ---"
          pip freeze
      - name: Hugging Face Login
        run: huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
      - name: Run SDXL Training
        env:
          NVTE_FRAMEWORK: jax
          TF_FORCE_GPU_ALLOW_GROWTH: "true"
        run: |
          python -m src.maxdiffusion.train_sdxl src/maxdiffusion/configs/base_xl.yml \
            run_name="sdxl-ci-test-${{ github.run_id }}" \
            output_dir="/tmp/sdxl-output/" \
            max_train_steps=5 \
            hardware=gpu \
            attention="cudnn_flash_te" \
            resolution=512 \
            per_device_batch_size=1 \
            train_new_unet=true \
            train_text_encoder=false \
            cache_latents_text_encoder_outputs=true