Commit 3331fed

Fix transformer sharding, flash block sizing, and tests
1 parent b4f9573 commit 3331fed

41 files changed

Lines changed: 1902 additions & 708 deletions


.github/workflows/UnitTests.yml

Lines changed: 1 addition & 3 deletions
```diff
@@ -42,8 +42,6 @@ jobs:
           python-version: '3.12'
       - name: Install dependencies
         run: |
-          pip install -e .
-          pip uninstall jax jaxlib libtpu-nightly libtpu -y
           bash setup.sh MODE=stable
           export PATH=$PATH:$HOME/.local/bin
           pip install ruff
@@ -66,4 +64,4 @@ jobs:
 # checks: read
 # pull-requests: write
 # needs: build
-# uses: ./.github/workflows/AddLabel.yml
+# uses: ./.github/workflows/AddLabel.yml
```

.github/workflows/UploadDockerImages.yml

Lines changed: 4 additions & 7 deletions
```diff
@@ -28,15 +28,12 @@ jobs:
   build-image:
     runs-on: ["self-hosted", "e2", "cpu"]
     steps:
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v5
     - name: Cleanup old docker images
       run: docker system prune --all --force
-    - name: build maxdiffusion jax ai image
+    - name: build maxdiffusion stable image
       run: |
-        bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
-    - name: build maxdiffusion w/ nightly jax ai image
-      run: |
-        bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack_nightly MODE=jax_ai_image PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/tpu-prod-env-multipod/jax-stable-stack/tpu/jax_nightly:latest
-    - name: build maxdiffusion jax nightly image
+        bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable MODE=stable PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_stable
+    - name: build maxdiffusion nightly image
       run: |
         bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_nightly MODE=nightly PROJECT=tpu-prod-env-multipod LOCAL_IMAGE_NAME=maxdiffusion_jax_nightly
```

.github/workflows/pypi_release.yml

Lines changed: 48 additions & 0 deletions
```diff
@@ -0,0 +1,48 @@
+# Copyright 2025 Google LLC
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# https://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This workflow will build, test and automatically release MaxDiffusion package to PyPI using Trusted Publishing (OIDC).
+
+name: Publish MaxDiffusion to PyPI
+
+# Triggers when a new "release" is published in the GitHub UI
+on:
+  release:
+    types: [published]
+  workflow_dispatch:
+
+permissions:
+  contents: read
+  id-token: write
+
+jobs:
+  build_and_publish:
+    name: Build and Publish MaxDiffusion Package
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v5
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.12'
+      - name: Install build dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build hatchling hatch-requirements-txt
+      - name: Build package
+        run: python -m build
+      - name: Publish package
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages-dir: dist/
```

Makefile

Lines changed: 2 additions & 2 deletions
```diff
@@ -18,11 +18,11 @@ modified_only_fixup:
 # Update src/maxdiffusion/dependency_versions_table.py

 deps_table_update:
-	@python setup.py deps_table_update
+	@python utils/update_dependency_table.py

 deps_table_check_updated:
 	@md5sum src/maxdiffusion/dependency_versions_table.py > md5sum.saved
-	@python setup.py deps_table_update
+	@python utils/update_dependency_table.py
 	@md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1)
 	@rm md5sum.saved
```
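The `deps_table_check_updated` target above detects a stale generated file by hashing it, regenerating, and re-checking the hash with `md5sum -c`. The same pattern in isolation, with illustrative file names (`generated.txt` stands in for the dependency table):

```shell
# Record the hash of a generated file before regenerating it.
printf 'v1\n' > generated.txt
md5sum generated.txt > md5sum.saved

# Simulate regeneration producing different content.
printf 'v2\n' > generated.txt

# --quiet suppresses per-file OK lines; a mismatch makes md5sum exit non-zero,
# so the || branch fires only when the file changed.
md5sum -c --quiet md5sum.saved 2>/dev/null || echo "generated file is outdated"

rm md5sum.saved generated.txt
```

The hash file is deleted either way, mirroring the Makefile's `@rm md5sum.saved` cleanup step.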

README.md

Lines changed: 57 additions & 0 deletions
````diff
@@ -17,6 +17,9 @@
 [![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

 # What's new?
+- **`2026/03/31`**: Wan2.2 SenCache inference is now supported for T2V and I2V (up to 1.4x speedup)
+- **`2026/03/25`**: Wan2.1 and Wan2.2 Magcache inference is now supported
+- **`2026/03/25`**: LTX-2 Video Inference is now supported
 - **`2026/01/29`**: Wan LoRA for inference is now supported
 - **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
 - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
@@ -49,6 +52,7 @@ MaxDiffusion supports
 * ControlNet inference (Stable Diffusion 1.4 & SDXL).
 * Dreambooth training support for Stable Diffusion 1.x,2.x.
 * LTX-Video text2vid, img2vid (inference).
+* LTX-2 Video text2vid (inference).
 * Wan2.1 text2vid (training and inference).
 * Wan2.2 text2vid (inference).

@@ -73,6 +77,7 @@ MaxDiffusion supports
 - [Inference](#inference)
   - [Wan](#wan-models)
   - [LTX-Video](#ltx-video)
+  - [LTX-2 Video](#ltx-2-video)
   - [Flux](#flux)
   - [Fused Attention for GPU](#fused-attention-for-gpu)
   - [SDXL](#stable-diffusion-xl)
@@ -497,6 +502,33 @@ To generate images, run the following command:

 Add conditioning image path as conditioning_media_paths in the form of ["IMAGE_PATH"] along with other generation parameters in the ltx_video.yml file. Then follow same instruction as above.

+## LTX-2 Video
+
+Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
+
+The following command will run LTX-2 T2V:
+
+```bash
+HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
+LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
+--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
+--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
+--xla_tpu_overlap_compute_collective_tc=true \
+--xla_enable_async_all_reduce=true" \
+HF_HUB_ENABLE_HF_TRANSFER=1 \
+python src/maxdiffusion/generate_ltx2.py \
+  src/maxdiffusion/configs/ltx2_video.yml \
+  attention="flash" \
+  num_inference_steps=40 \
+  num_frames=121 \
+  width=768 \
+  height=512 \
+  per_device_batch_size=.125 \
+  ici_data_parallelism=2 \
+  ici_context_parallelism=4 \
+  run_name=ltx2-inference
+```
+
 ## Wan Models

 Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
@@ -540,6 +572,31 @@ To generate images, run the following command:
 * For Wan2.2 T2V, use `base_wan_27b.yml`.
 * For Wan2.2 I2V, use `base_wan_i2v_27b.yml`.

+### Caching Mechanisms
+
+Wan 2.x pipelines support several caching strategies to accelerate inference by skipping redundant transformer forward passes. These are **mutually exclusive**: enable only one at a time.
+
+| Cache Type | Config Flag | Supported Pipelines | Speedup | Description |
+| --- | --- | --- | --- | --- |
+| **CFG Cache** | `use_cfg_cache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.2x | FasterCache-style: caches the unconditional branch and applies FFT frequency-domain compensation on skipped steps. |
+| **SenCache** | `use_sen_cache: True` | Wan 2.2 T2V/I2V | ~1.4x | Sensitivity-Aware Caching ([arXiv:2602.24208](https://arxiv.org/abs/2602.24208)): predicts output change via first-order sensitivity S = α_x·‖Δx‖ + α_t·\|Δt\|. Skips the full CFG forward pass when predicted change is below tolerance ε. |
+
+To enable a caching mechanism, set the corresponding flag in your config YAML or pass it as a command-line override:
+
+```bash
+# Example: enable SenCache for Wan 2.2 T2V
+python src/maxdiffusion/generate_wan.py \
+  src/maxdiffusion/configs/base_wan_27b.yml \
+  use_sen_cache=True \
+  ...
+
+# Example: enable CFG Cache for Wan 2.2 I2V
+python src/maxdiffusion/generate_wan.py \
+  src/maxdiffusion/configs/base_wan_i2v_27b.yml \
+  use_cfg_cache=True \
+  ...
+```
+
 ## Flux

 First make sure you have permissions to access the Flux repos in Huggingface.
````
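The skip rule behind SenCache, as summarized in the README's Caching Mechanisms table, compares a first-order sensitivity estimate against a tolerance. A minimal sketch of that decision in isolation; `alpha_x`, `alpha_t`, and `eps` are illustrative names here, not MaxDiffusion config keys:

```python
import math

def sensitivity(x_prev, x_curr, t_prev, t_curr, alpha_x=1.0, alpha_t=1.0):
    """First-order sensitivity S = alpha_x * ||dx|| + alpha_t * |dt|,
    a cheap proxy for how much the transformer output is expected to change."""
    dx_norm = math.sqrt(sum((a - b) ** 2 for a, b in zip(x_curr, x_prev)))
    return alpha_x * dx_norm + alpha_t * abs(t_curr - t_prev)

def should_skip(x_prev, x_curr, t_prev, t_curr, eps=0.5):
    """Reuse the cached output (skip the full CFG forward pass) when the
    predicted change is below tolerance eps."""
    return sensitivity(x_prev, x_curr, t_prev, t_curr) < eps
```

When the latent and timestep barely moved between steps, `should_skip` returns True and the cached transformer output is reused; a large latent delta or timestep jump forces a fresh forward pass.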
Lines changed: 43 additions & 0 deletions
```diff
@@ -0,0 +1,43 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+absl-py
+aqtp
+chex
+datasets
+einops
+flax
+ftfy
+google-cloud-storage
+grain
+hf_transfer
+huggingface_hub
+imageio-ffmpeg
+imageio
+jax
+jaxlib
+Jinja2
+opencv-python-headless
+optax
+orbax-checkpoint
+parameterized
+Pillow
+pyink
+pylint
+pytest
+ruff
+scikit-image
+sentencepiece
+tensorboard-plugin-profile
+tensorboard
+tensorboardx
+tensorflow-datasets
+tensorflow
+tokamax
+tokenizers
+transformers<5.0.0
+
+# pinning torch and torchvision to specific versions to avoid
+# installing GPU versions from PyPI when running seed-env
+torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
+torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl
+qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip
```
