
Commit 446ec05

Feat: Ring attention kernel and VAE optimization

1 parent 18f6f0f

37 files changed

Lines changed: 7546 additions & 124 deletions

.github/workflows/UnitTests.yml

Lines changed: 3 additions & 3 deletions
```diff
@@ -57,11 +57,11 @@ jobs:
       - name: PyTest
         run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
           export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
-          HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
-  # add_pull_ready:
+          HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
+  # add_pull_ready
   #   if: github.ref != 'refs/heads/main'
   #   permissions:
   #     checks: read
   #     pull-requests: write
   #     needs: build
-  #   uses: ./.github/workflows/AddLabel.yml
+  #   uses: ./.github/workflows/AddLabel.yml
```

README.md

Lines changed: 20 additions & 1 deletion
```diff
@@ -17,6 +17,7 @@
 [![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)
 
 # What's new?
+- **`2026/04/16`**: Support for the Tokamax Ring Attention kernel has been added.
 - **`2026/03/31`**: Wan2.2 SenCache inference is now supported for T2V and I2V (up to 1.4x speedup)
 - **`2026/03/25`**: Wan2.1 and Wan2.2 Magcache inference is now supported
 - **`2026/03/25`**: LTX-2 Video Inference is now supported
```
````diff
@@ -603,6 +604,24 @@ To generate images, run the following command:
 ...
 ```
 
+### Ring Attention
+We added ring attention support for Wan models. Below are the stats for generating one `720p` (81-frame) video (with CFG DP):
+| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time (s) |
+| -- | -- | -- | -- | -- | -- |
+| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
+| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
+| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
+| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |
+
+| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time (s) |
+| -- | -- | -- | -- | -- | -- |
+| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
+| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** |
+| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
+| v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |
+
+(* There are known stability issues with ring attention on 16 TPUs; please use `tokamax_flash` attention instead.)
+
 ## Flux
 
 First make sure you have permissions to access the Flux repos in Huggingface.
````
```diff
@@ -780,4 +799,4 @@ This script will automatically format your code with `pyink` and help you identi
 The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadence.
 
 ## Profiling
-To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
+To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
```
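
For intuition about what the new `tokamax_ring` option in the Ring Attention section above does: ring attention shards the sequence over the `context` mesh axis and rotates K/V shards around the device ring while accumulating an online softmax. Below is a minimal single-head JAX sketch of that pattern; it is illustrative only (plain `ppermute`, no fused kernel), not the Tokamax kernel this commit adds, and the function and axis names are our own assumptions.

```python
# Minimal single-head ring-attention sketch (NOT the Tokamax kernel).
# Each device holds one shard of the sequence; K/V shards rotate around
# the 'context' mesh axis while an online softmax accumulates the output.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def ring_attention(q, k, v, axis_name="context"):
  n = jax.lax.psum(1, axis_name)               # number of devices in the ring
  perm = [(i, (i + 1) % n) for i in range(n)]  # pass K/V to the next device
  scale = q.shape[-1] ** -0.5
  m = jnp.full(q.shape[:-1], -jnp.inf, q.dtype)  # running row max
  l = jnp.zeros(q.shape[:-1], q.dtype)           # running softmax denominator
  o = jnp.zeros_like(q)                          # unnormalized output
  for _ in range(n):
    s = (q @ k.T) * scale                        # local scores [q_len, kv_len]
    m_new = jnp.maximum(m, s.max(axis=-1))
    p = jnp.exp(s - m_new[:, None])
    c = jnp.exp(m - m_new)                       # rescale old accumulators
    l = l * c + p.sum(axis=-1)
    o = o * c[:, None] + p @ v
    m = m_new
    k = jax.lax.ppermute(k, axis_name, perm)     # rotate K/V one hop
    v = jax.lax.ppermute(v, axis_name, perm)
  return o / l[:, None]

mesh = Mesh(np.array(jax.devices()), ("context",))
attn = shard_map(ring_attention, mesh=mesh,
                 in_specs=(P("context"), P("context"), P("context")),
                 out_specs=P("context"))
```

In a production kernel each hop's attention block is overlapped with the `ppermute` transfer, which is consistent with the ring variant edging out flash attention in most rows of the tables above.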

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 15 additions & 1 deletion
```diff
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
+vae_spatial: -1 # default: total_devices * 2 // dp
 
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -60,7 +61,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
+attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
@@ -180,6 +181,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+  ['conv_in', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
```
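
The `vae_logical_axis_rules` added here follow Flax's logical-partitioning convention: each pair maps a logical axis name used in the module's sharding annotations to a mesh axis, with `null` meaning the axis is replicated. A rough sketch of how such rules resolve to a concrete sharding; the mesh layout and axis sizes below are illustrative assumptions, not maxdiffusion internals:

```python
# Sketch: resolving logical-axis rules like vae_logical_axis_rules above
# into a device sharding with Flax's logical partitioning helpers.
import jax
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh, PartitionSpec as P

rules = (
    ("activation_batch", "redundant"),
    ("activation_length", "vae_spatial"),
    ("embed", None),  # null in the YAML -> replicated
)

devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("redundant", "vae_spatial"))

# A VAE activation annotated with logical axes (batch, length, embed):
logical = P("activation_batch", "activation_length", "embed")
sharding = nn.logical_to_mesh_sharding(logical, mesh, rules)
print(sharding.spec)  # roughly PartitionSpec('redundant', 'vae_spatial', None)
```

A separate rule set like this appears to let the VAE shard its long spatial axis over `vae_spatial` devices independently of the transformer's `context`/`tensor` rules.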

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 14 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-# Copyright 2023 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -157,6 +157,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+  ['conv_in', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
```

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 14 additions & 0 deletions
```diff
@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
+vae_spatial: 1
 
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
@@ -168,6 +169,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+  ['conv_in', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
```

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 13 additions & 0 deletions
```diff
@@ -163,6 +163,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+  ['conv_in', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
```

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 13 additions & 0 deletions
```diff
@@ -164,6 +164,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+  ['conv_in', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
```

src/maxdiffusion/configuration_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -394,7 +394,7 @@ def load_config(
         proxies=proxies,
         resume_download=resume_download,
         local_files_only=local_files_only,
-        use_auth_token=use_auth_token,
+        token=use_auth_token,
         user_agent=user_agent,
         subfolder=subfolder,
         revision=revision,
```
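
This one-line rename tracks huggingface_hub, which deprecated the `use_auth_token` keyword in favor of `token`. A minimal sketch of the call shape (the repo id and filename are placeholders):

```python
# huggingface_hub renamed `use_auth_token` to `token`; the old keyword is
# deprecated and rejected by newer releases.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="some-org/some-model",  # placeholder
    filename="config.json",
    token=None,  # was: use_auth_token=None
)
```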

src/maxdiffusion/kernels/__init__.py

Whitespace-only changes.
Lines changed: 15 additions & 0 deletions
```diff
@@ -0,0 +1,15 @@
+# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Splash Attention kernels."""
```
