Skip to content

Commit 6e0898a

Browse files
Update entrypoint for jaii
Test maxdiffusion workload on gpu image
1 parent 75d16a4 commit 6e0898a

2 files changed

Lines changed: 66 additions & 35 deletions

File tree

.github/workflows/UnitTests.yml

Lines changed: 65 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,43 +22,74 @@ on:
2222
push:
2323
branches: [ "main" ]
2424
workflow_dispatch:
25-
schedule:
26-
# Run the job every 12 hours
27-
- cron: '0 */12 * * *'
2825

2926
jobs:
30-
build:
31-
strategy:
32-
fail-fast: false
33-
matrix:
34-
tpu-type: ["v5p-8"]
35-
name: "TPU test (${{ matrix.tpu-type }})"
36-
runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
27+
# STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD
28+
maxdiffusion_workload:
29+
name: "Run MaxDiffusion Workload"
30+
# IMPORTANT: Replace with the label for your runner (e.g., v5p-8)
31+
runs-on: ["self-hosted", "linux-x86-a2-48-a100-4gpu"]
32+
container:
33+
image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest
3734
steps:
38-
- uses: actions/checkout@v4
39-
- name: Set up Python 3.12
40-
uses: actions/setup-python@v5
41-
with:
42-
python-version: '3.12'
43-
- name: Install dependencies
44-
run: |
45-
pip install -e .
46-
pip uninstall jax jaxlib libtpu-nightly libtpu -y
47-
bash setup.sh MODE=stable
48-
export PATH=$PATH:$HOME/.local/bin
49-
pip install ruff
50-
pip install isort
51-
pip install pytest
52-
- name: Analysing the code with ruff
53-
run: |
54-
ruff check .
55-
- name: version check
56-
run: |
57-
python --version
58-
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
59-
- name: PyTest
60-
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
61-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
35+
- name: Checkout Repository
36+
uses: actions/checkout@v4
37+
38+
- name: Run MaxDiffusion Training
39+
run: |
40+
# This command is adapted from your DAG for a single-slice configuration.
41+
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \
42+
TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \
43+
JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \
44+
pip install . && \
45+
python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
46+
pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \
47+
revision=refs/pr/95 \
48+
activations_dtype=bfloat16 \
49+
weights_dtype=bfloat16 \
50+
dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \
51+
resolution=1024 \
52+
per_device_batch_size=1 \
53+
jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \
54+
max_train_steps=20 \
55+
attention=flash \
56+
enable_profiler=True \
57+
run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \
58+
output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }}
59+
60+
# jobs:
61+
# build:
62+
# strategy:
63+
# fail-fast: false
64+
# matrix:
65+
# tpu-type: ["v5p-8"]
66+
# name: "TPU test (${{ matrix.tpu-type }})"
67+
# runs-on: ["self-hosted","${{ matrix.tpu-type }}"]
68+
# steps:
69+
# - uses: actions/checkout@v4
70+
# - name: Set up Python 3.12
71+
# uses: actions/setup-python@v5
72+
# with:
73+
# python-version: '3.12'
74+
# - name: Install dependencies
75+
# run: |
76+
# pip install -e .
77+
# pip uninstall jax jaxlib libtpu-nightly libtpu -y
78+
# bash setup.sh MODE=stable
79+
# export PATH=$PATH:$HOME/.local/bin
80+
# pip install ruff
81+
# pip install isort
82+
# pip install pytest
83+
# - name: Analysing the code with ruff
84+
# run: |
85+
# ruff check .
86+
# - name: version check
87+
# run: |
88+
# python --version
89+
# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
90+
# - name: PyTest
91+
# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
92+
# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
6293
# add_pull_ready:
6394
# if: github.ref != 'refs/heads/main'
6495
# permissions:

maxdiffusion_jax_ai_image_tpu.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ COPY . .
1919
RUN pip install -r /deps/requirements_with_jax_ai_image.txt
2020

2121
# Run the script available in JAX-AI-Image base image to generate the manifest file
22-
RUN bash /jax-stable-stack/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH
22+
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH

0 commit comments

Comments
 (0)