|
22 | 22 | push: |
23 | 23 | branches: [ "main" ] |
24 | 24 | workflow_dispatch: |
25 | | - schedule: |
26 | | - # Run the job every 12 hours |
27 | | - - cron: '0 */12 * * *' |
28 | 25 |
|
29 | 26 | jobs: |
30 | | - build: |
31 | | - strategy: |
32 | | - fail-fast: false |
33 | | - matrix: |
34 | | - tpu-type: ["v5p-8"] |
35 | | - name: "TPU test (${{ matrix.tpu-type }})" |
36 | | - runs-on: ["self-hosted","${{ matrix.tpu-type }}"] |
| 27 | + # STAGE 1: PULL MAXDIFFUSION IMAGE AND RUN WORKLOAD |
| 28 | + maxdiffusion_workload: |
| 29 | + name: "Run MaxDiffusion Workload" |
| 30 | + # IMPORTANT: Replace with the label for your runner (e.g., v5p-8) |
| 31 | + runs-on: ["self-hosted", "linux-x86-a2-48-a100-4gpu"] |
| 32 | + container: |
| 33 | + image: gcr.io/tpu-prod-env-multipod/maxdiffusion_stable_stack_candidate:latest |
37 | 34 | steps: |
38 | | - - uses: actions/checkout@v4 |
39 | | - - name: Set up Python 3.12 |
40 | | - uses: actions/setup-python@v5 |
41 | | - with: |
42 | | - python-version: '3.12' |
43 | | - - name: Install dependencies |
44 | | - run: | |
45 | | - pip install -e . |
46 | | - pip uninstall jax jaxlib libtpu-nightly libtpu -y |
47 | | - bash setup.sh MODE=stable |
48 | | - export PATH=$PATH:$HOME/.local/bin |
49 | | - pip install ruff |
50 | | - pip install isort |
51 | | - pip install pytest |
52 | | - - name: Analysing the code with ruff |
53 | | - run: | |
54 | | - ruff check . |
55 | | - - name: version check |
56 | | - run: | |
57 | | - python --version |
58 | | - pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets |
59 | | - - name: PyTest |
60 | | - run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py |
61 | | - HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x |
| 35 | + - name: Checkout Repository |
| 36 | + uses: actions/checkout@v4 |
| 37 | + |
| 38 | + - name: Run MaxDiffusion Training |
| 39 | + run: | |
| 40 | + # This command is adapted from your DAG for a single-slice configuration. |
| 41 | + JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true \ |
| 42 | + TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true \ |
| 43 | + JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && \ |
| 44 | + pip install . && \ |
| 45 | + python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \ |
| 46 | + pretrained_model_name_or_path=gs://maxdiffusion-github-runner-test-assets/checkpoints/models--stabilityai--stable-diffusion-xl-base-1.0 \ |
| 47 | + revision=refs/pr/95 \ |
| 48 | + activations_dtype=bfloat16 \ |
| 49 | + weights_dtype=bfloat16 \ |
| 50 | + dataset_name=gs://jfacevedo-maxdiffusion-v5p/pokemon-datasets/pokemon-gpt4-captions_sdxl \ |
| 51 | + resolution=1024 \ |
| 52 | + per_device_batch_size=1 \ |
| 53 | + jax_cache_dir=gs://jfacevedo-maxdiffusion/cache_dir/ \ |
| 54 | + max_train_steps=20 \ |
| 55 | + attention=flash \ |
| 56 | + enable_profiler=True \ |
| 57 | + run_name=1slice-maxdiffusion-stable-stack-${{ github.run_id }} \ |
| 58 | + output_dir=gs://your-output-bucket/maxdiffusion-jax-stable-stack/automated/${{ github.run_id }} |
| 59 | +
|
| 60 | +# jobs: |
| 61 | +# build: |
| 62 | +# strategy: |
| 63 | +# fail-fast: false |
| 64 | +# matrix: |
| 65 | +# tpu-type: ["v5p-8"] |
| 66 | +# name: "TPU test (${{ matrix.tpu-type }})" |
| 67 | +# runs-on: ["self-hosted","${{ matrix.tpu-type }}"] |
| 68 | +# steps: |
| 69 | +# - uses: actions/checkout@v4 |
| 70 | +# - name: Set up Python 3.12 |
| 71 | +# uses: actions/setup-python@v5 |
| 72 | +# with: |
| 73 | +# python-version: '3.12' |
| 74 | +# - name: Install dependencies |
| 75 | +# run: | |
| 76 | +# pip install -e . |
| 77 | +# pip uninstall jax jaxlib libtpu-nightly libtpu -y |
| 78 | +# bash setup.sh MODE=stable |
| 79 | +# export PATH=$PATH:$HOME/.local/bin |
| 80 | +# pip install ruff |
| 81 | +# pip install isort |
| 82 | +# pip install pytest |
| 83 | +# - name: Analysing the code with ruff |
| 84 | +# run: | |
| 85 | +# ruff check . |
| 86 | +# - name: version check |
| 87 | +# run: | |
| 88 | +# python --version |
| 89 | +# pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets |
| 90 | +# - name: PyTest |
| 91 | +# run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py |
| 92 | +# HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x |
62 | 93 | # add_pull_ready: |
63 | 94 | # if: github.ref != 'refs/heads/main' |
64 | 95 | # permissions: |
|
0 commit comments