
Commit ddbce4a

Merge origin/main into torchax_attention and make ulysses_custom conditional

2 parents: 8cd6da3 + c98002f

15 files changed: 318 additions & 54 deletions

README.md (7 additions, 7 deletions)

```diff
@@ -628,16 +628,16 @@ To generate images, run the following command:
 We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP):
 | Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
 | -- | -- | -- | -- | -- | -- |
-| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
-| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
-| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
-| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |
+| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | **249.3** |
+| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | 252.4 |
+| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | **194.4** |
+| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | 201.7 |
 
 | Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
 | -- | -- | -- | -- | -- | -- |
-| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
-| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** |
-| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
+| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | **127.1** |
+| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | 137.2 |
+| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **106.0** |
 | v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |
 
 (* There are some known stability issues for ring attention on 16 TPUs, please use `tokamax_flash` attention instead.)
```
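The `ring`/`tokamax_ring` options benchmarked above shard the sequence across devices: each device's queries eventually see every key/value block as K/V chunks rotate around the ring, while a streaming (online) softmax keeps the result exact without materializing the full score matrix. A single-host sketch of that chunked accumulation, assuming standard scaled dot-product attention (illustrative only, not the Tokamax kernel):

```python
import numpy as np

def chunked_attention(q, k, v, num_chunks=4):
    """Exact attention computed one K/V chunk at a time, ring-attention style.

    A running max `m`, normalizer `l`, and weighted sum `acc` are rescaled
    as each chunk arrives, so the full (seq x seq) score matrix is never built.
    """
    scale = q.shape[-1] ** -0.5
    m = np.full(q.shape[:-1], -np.inf)   # running row-wise max of scores
    l = np.zeros(q.shape[:-1])           # running softmax normalizer
    acc = np.zeros_like(q)               # running weighted sum of V
    for k_c, v_c in zip(np.array_split(k, num_chunks), np.array_split(v, num_chunks)):
        s = (q @ k_c.T) * scale                  # scores against this chunk
        m_new = np.maximum(m, s.max(axis=-1))
        p = np.exp(s - m_new[:, None])
        alpha = np.exp(m - m_new)                # rescale previously seen chunks
        l = l * alpha + p.sum(axis=-1)
        acc = acc * alpha[:, None] + p @ v_c
        m = m_new
    return acc / l[:, None]

def full_attention(q, k, v):
    """Reference: materialize all scores at once."""
    s = (q @ k.T) * (q.shape[-1] ** -0.5)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = [rng.normal(size=(16, 8)) for _ in range(3)]
assert np.allclose(chunked_attention(q, k, v), full_attention(q, k, v))
```

In the real kernel, each loop iteration corresponds to a neighbor-to-neighbor K/V exchange on the device ring, which is why its end-to-end numbers above track the flash variant closely.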

src/maxdiffusion/configs/base_wan_14b.yml (2 additions, 0 deletions)

```diff
@@ -62,6 +62,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
+use_base2_exp: True
+use_experimental_scheduler: True
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
```
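The new `use_base2_exp` flag is not documented in this diff; a common reason for such a flag is evaluating the attention softmax with base-2 exponentials, since `exp(x) = 2**(x * log2(e))` and `exp2` is cheaper than `exp` on many accelerators. A minimal NumPy sketch of that identity (function names here are illustrative, not maxdiffusion APIs):

```python
import numpy as np

LOG2_E = 1.4426950408889634  # log2(e)

def softmax_exp(x):
    """Reference softmax using the natural exponential."""
    x = x - x.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_exp2(x):
    """Identical softmax via base-2 exponentials: exp(x) = 2**(x * log2(e))."""
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp2(x * LOG2_E)
    return e / e.sum(axis=-1, keepdims=True)

scores = np.random.default_rng(0).normal(size=(4, 8))
assert np.allclose(softmax_exp(scores), softmax_exp2(scores))
```

Because the log2(e) factor can be folded into the query scaling ahead of time, the rewrite changes the kernel's instruction mix but not its output.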

src/maxdiffusion/configs/base_wan_1_3b.yml (2 additions, 0 deletions)

```diff
@@ -61,6 +61,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: True
+use_experimental_scheduler: True
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
```

src/maxdiffusion/configs/base_wan_27b.yml (2 additions, 0 deletions)

```diff
@@ -62,6 +62,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: True
+use_experimental_scheduler: True
 flash_min_seq_length: 4096
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
```

src/maxdiffusion/configs/base_wan_i2v_14b.yml (2 additions, 0 deletions)

```diff
@@ -61,6 +61,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: True
+use_experimental_scheduler: True
 flash_min_seq_length: 4096
 dropout: 0.0
 
```
src/maxdiffusion/configs/base_wan_i2v_27b.yml (2 additions, 0 deletions)

```diff
@@ -61,6 +61,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: True
+use_experimental_scheduler: True
 flash_min_seq_length: 4096
 dropout: 0.0
 
```

src/maxdiffusion/generate_wan.py (1 addition, 8 deletions)

```diff
@@ -302,14 +302,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
       f"{'=' * 50}"
   )
 
-  s0 = time.perf_counter()
-  if max_utils.profiler_enabled(config):
-    with max_utils.Profiler(config):
-      videos = call_pipeline(config, pipeline, prompt, negative_prompt)
-  generation_time_with_profiler = time.perf_counter() - s0
-  max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
-  if writer and jax.process_index() == 0:
-    writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
+  videos = call_pipeline(config, pipeline, prompt, negative_prompt)
 
   return saved_video_path
```
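The removed block timed the profiler-wrapped pipeline call with `time.perf_counter` and logged the elapsed seconds; the merge replaces it with an unconditional `call_pipeline`. If that timing pattern is wanted elsewhere, it factors cleanly into a small context manager. A sketch (the `timed` helper and its `log` callback are stand-ins, not maxdiffusion APIs):

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label, log=print):
    """Time the wrapped block with perf_counter and log the elapsed seconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        log(f"{label}: {time.perf_counter() - start:.3f}s")

# Usage: wraps a call the way the removed code wrapped call_pipeline(...).
with timed("generation_time_with_profiler"):
    time.sleep(0.01)  # stand-in for the pipeline call
```

The `finally` clause ensures the duration is logged even if the wrapped call raises, which the removed inline version did not guarantee.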
