# --- Page-chrome from the GitHub web view (not part of the workflow) ---
# [Do not Merge] Test maxdiffusion workload on gpu runner #3230
# Workflow file for this run
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: Unit Test
on:
pull_request:
push:
branches: [ "main" ]
workflow_dispatch:
jobs:
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
maxdiffusion_workload:
name: "Run MaxDiffusion Workload"
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
runs-on: ["linux-x86-a2-48-a100-4gpu"]
container:
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
- name: Run MaxDiffusion Training
run: |
# This command is adapted from your DAG for a single-slice configuration.
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \
TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \
JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \
pip install . && \
python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
revision=refs/pr/95 \
activations_dtype=bfloat16 \
weights_dtype=bfloat16 \
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \
resolution=1024 \
per_device_batch_size=1 \
jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \
max_train_steps=20 \
attention=flash \
enable_profiler=True \
run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \
output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }}
# jobs:
# build:
# strategy:
# fail-fast: false
# matrix:
# tpu-type: ["v5p-8"]
# name: "TPU test (${{ matrix.tpu-type }})"
# runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python 3.12
# uses: actions/setup-python@v5
# with:
# python-version: '3.12'
# - name: Install dependencies
# run: |
# pip install -e .
# pip uninstall jax jaxlib libtpu-nightly libtpu -y
# bash setup.sh MODE=stable
# export PATH=$PATH:$HOME/.local/bin
# pip install ruff
# pip install isort
# pip install pytest
# - name: Analysing the code with ruff
# run: |
# ruff check .
# - name: version check
# run: |
# python --version
# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
# - name: PyTest
# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
# add_pull_ready:
# if: github.ref != 'refs/heads/main'
# permissions:
# checks: read
# pull-requests: write
# needs: build
# uses: ./.github/workflows/AddLabel.yml