Commit 859e4b3 ("debug")
Parent: 9de3833
3 files changed: 4 additions and 2 deletions

src/maxdiffusion/configs/base_wan_i2v_14b.yml (1 addition, 1 deletion)

@@ -61,7 +61,7 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
-flash_min_seq_length: 4096
+flash_min_seq_length: 0
 dropout: 0.1
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_i2v_27b.yml (1 addition, 1 deletion)

@@ -61,7 +61,7 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring
-flash_min_seq_length: 4096
+flash_min_seq_length: 0
 dropout: 0.1
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
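Context on the change made in both base_wan_i2v_14b.yml and base_wan_i2v_27b.yml: flash_min_seq_length appears to act as the minimum sequence length at which the flash kernel is used, with shorter sequences falling back to dot-product attention; dropping it from 4096 to 0 would route every sequence length through the flash path, which lines up with the debug prints added in attention_flax.py below. A minimal sketch of that dispatch, assuming a hypothetical helper choose_attention (the actual guard in the repository may differ):

def choose_attention(seq_len: int, flash_min_seq_length: int) -> str:
    # Illustrative assumption: sequences shorter than the threshold
    # skip the flash kernel and use plain dot-product attention.
    if seq_len >= flash_min_seq_length:
        return "flash"
    return "dot_product"

# With the old default of 4096, a 1024-token sequence fell back to
# dot-product attention; with 0, every length takes the flash path.
assert choose_attention(1024, 4096) == "dot_product"
assert choose_attention(1024, 0) == "flash"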

src/maxdiffusion/models/attention_flax.py (2 additions, 0 deletions)

@@ -225,6 +225,7 @@ def _tpu_flash_attention(
     attention_mask: jax.Array = None,
 ) -> jax.Array:
   """TPU Flash Attention"""
+  jax.debug.print("USing FLASH ATTENTION")
 
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   # This is the case for cross-attn.
@@ -444,6 +445,7 @@ def _apply_attention_dot(
     float32_qk_product: bool,
     use_memory_efficient_attention: bool,
 ):
+  jax.debug.print("Using DOT PRODUCT ATTENTION")
   """Apply Attention."""
   if split_head_dim:
     b = key.shape[0]
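Background on the two added lines: unlike Python's built-in print, jax.debug.print executes on every call of a jit-compiled function rather than only once at trace time, which makes it suitable for checking which attention path actually runs. A minimal, self-contained example, independent of this repository:

import jax
import jax.numpy as jnp

@jax.jit
def attend(q, k):
    scores = q @ k.T
    # An ordinary print() here would fire once during tracing and then
    # go silent; jax.debug.print runs on every call of the compiled fn.
    jax.debug.print("max attention score: {}", jnp.max(scores))
    return jax.nn.softmax(scores, axis=-1)

q = jnp.ones((4, 8))
k = jnp.ones((4, 8))
out = attend(q, k)  # the debug message prints from inside the jitted code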
