
Commit 2baa447

Pass two new options (use_base2_exp, use_experimental_scheduler) through as configs
1 parent 3f752cb commit 2baa447

8 files changed

Lines changed: 56 additions & 2 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
+use_base2_exp: False
+use_experimental_scheduler: False
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: False
+use_experimental_scheduler: False
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: False
+use_experimental_scheduler: False
 flash_min_seq_length: 4096
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
 # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: False
+use_experimental_scheduler: False
 flash_min_seq_length: 4096
 dropout: 0.0

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,8 @@ jit_initializers: True
 from_pt: True
 split_head_dim: True
 attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+use_base2_exp: False
+use_experimental_scheduler: False
 flash_min_seq_length: 4096
 dropout: 0.0
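
Both flags default to False in every Wan base config. As a quick check, here is a minimal sketch (illustrative only, using PyYAML rather than MaxDiffusion's own config loader) of reading the new keys back out of one of these files:

```python
# Illustrative sketch: load one of the Wan base configs with PyYAML and read
# the two new flags, falling back to False when a key is missing (the same
# default the YAML files above now set explicitly).
import yaml

with open("src/maxdiffusion/configs/base_wan_14b.yml") as f:
    cfg = yaml.safe_load(f)

use_base2_exp = cfg.get("use_base2_exp", False)
use_experimental_scheduler = cfg.get("use_experimental_scheduler", False)
print(use_base2_exp, use_experimental_scheduler)
```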

src/maxdiffusion/models/attention_flax.py

Lines changed: 34 additions & 2 deletions
@@ -272,6 +272,7 @@ def convert_to_tokamax_splash_config(
 attn_logits_soft_cap: float | None = None,
 fuse_reciprocal: bool = True,
 use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 max_logit_const: float | None = None,
 interpret: bool = False,
 dq_reduction_steps: int | None = None,
@@ -294,6 +295,7 @@ def convert_to_tokamax_splash_config(
 attn_logits_soft_cap=attn_logits_soft_cap,
 fuse_reciprocal=fuse_reciprocal,
 use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
 max_logit_const=max_logit_const,
 interpret=interpret,
 dq_reduction_steps=dq_reduction_steps,
@@ -314,6 +316,8 @@ def _tpu_flash_attention(
 mask_padding_tokens: bool = True,
 residual_checkpoint_name: str | None = None,
 attention_mask: jax.Array = None,
+use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 ) -> jax.Array:
 """TPU Flash Attention"""
 
@@ -399,7 +403,12 @@ def wrap_flash_attention(query, key, value):
 splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
 mask=mask,
 q_seq_shards=1, # the sizes of the axis is sharding over seq_len
-config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+config=convert_to_tokamax_splash_config(
+block_sizes,
+residual_checkpoint_name=residual_checkpoint_name,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
+),
 save_residuals=False,
 )
 elif attention_kernel == "tokamax_ring":
@@ -409,7 +418,12 @@ def wrap_flash_attention(query, key, value):
 splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
 mask=mask,
 is_mqa=False,
-config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
+config=convert_to_tokamax_splash_config(
+block_sizes,
+residual_checkpoint_name=residual_checkpoint_name,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
+),
 save_residuals=False,
 ring_axis="context",
 rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids
@@ -741,6 +755,8 @@ def _apply_attention(
 mask_padding_tokens: bool = True,
 residual_checkpoint_name: str | None = None,
 attention_mask: Array = None,
+use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 ):
 """Routes to different attention kernels."""
 _check_attention_inputs(query, key, value)
@@ -789,6 +805,8 @@ def _apply_attention(
 mask_padding_tokens=mask_padding_tokens,
 residual_checkpoint_name=residual_checkpoint_name,
 attention_mask=attention_mask,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
 )
 elif "ring" in attention_kernel:
 return _tpu_flash_attention(
@@ -983,8 +1001,12 @@ def __init__(
 quant: Quant = None,
 mask_padding_tokens: bool = True,
 residual_checkpoint_name: str | None = None,
+use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 ):
 self.dpa_layer = None
+self.use_base2_exp = use_base2_exp
+self.use_experimental_scheduler = use_experimental_scheduler
 if attention_kernel == "cudnn_flash_te":
 from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
 
@@ -1045,6 +1067,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
 mask_padding_tokens=self.mask_padding_tokens,
 residual_checkpoint_name=self.residual_checkpoint_name,
 attention_mask=attention_mask,
+use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False,
+use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False,
 )
 
 
@@ -1063,6 +1087,8 @@ class AttentionOp(nn.Module):
 flash_block_sizes: BlockSizes = None
 dtype: DType = jnp.float32
 quant: Quant = None
+use_base2_exp: bool = False
+use_experimental_scheduler: bool = False
 
 def setup(self):
 self.dpa_layer = None
@@ -1108,6 +1134,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
 flash_block_sizes=self.flash_block_sizes,
 dpa_layer=self.dpa_layer,
 attention_mask=attention_mask,
+use_base2_exp=self.use_base2_exp,
+use_experimental_scheduler=self.use_experimental_scheduler,
 )
 
 
@@ -1144,6 +1172,8 @@ def __init__(
 enable_jax_named_scopes: bool = False,
 added_kv_proj_dim: Optional[int] = None, # New for I2V
 image_seq_len: Optional[int] = None, # New for I2V
+use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 ):
 if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
 raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
@@ -1186,6 +1216,8 @@ def __init__(
 quant=quant,
 mask_padding_tokens=mask_padding_tokens,
 residual_checkpoint_name=residual_checkpoint_name,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
 )
 # None axes corresponds to the stacked weights across all blocks
 # because of the use of nnx.vmap and nnx.scan.
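
The attention changes are plumbing: each layer gains the two keyword arguments with a False default and forwards them unchanged until they reach the tokamax splash/ring kernel config built by convert_to_tokamax_splash_config. A self-contained sketch of that pattern (simplified names, not the actual MaxDiffusion signatures):

```python
# Simplified, self-contained illustration of the keyword-threading pattern used
# above; the real code forwards the flags into the tokamax kernel config via
# convert_to_tokamax_splash_config.
from dataclasses import dataclass


@dataclass
class KernelConfig:
    use_base2_exp: bool = False
    use_experimental_scheduler: bool = False


def make_kernel_config(use_base2_exp: bool = False,
                       use_experimental_scheduler: bool = False) -> KernelConfig:
    # Innermost layer: the flags land in the kernel config object.
    return KernelConfig(use_base2_exp, use_experimental_scheduler)


def flash_attention(use_base2_exp: bool = False,
                    use_experimental_scheduler: bool = False) -> KernelConfig:
    # Outer layers pass the flags through with the same False defaults,
    # so existing call sites that omit them keep the old behavior.
    return make_kernel_config(
        use_base2_exp=use_base2_exp,
        use_experimental_scheduler=use_experimental_scheduler,
    )


print(flash_attention())  # defaults: both False
print(flash_attention(use_base2_exp=True, use_experimental_scheduler=True))  # opt in
```

The hasattr guards in apply_attention appear to be there so that instances constructed before these attributes existed still fall back to False.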

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 10 additions & 0 deletions
@@ -291,6 +291,8 @@ def __init__(
 dropout: float = 0.0,
 mask_padding_tokens: bool = True,
 enable_jax_named_scopes: bool = False,
+use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 ):
 self.enable_jax_named_scopes = enable_jax_named_scopes
 
@@ -315,6 +317,8 @@ def __init__(
 mask_padding_tokens=mask_padding_tokens,
 residual_checkpoint_name="self_attn",
 enable_jax_named_scopes=enable_jax_named_scopes,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
 )
 
 # 1. Cross-attention
@@ -339,6 +343,8 @@ def __init__(
 mask_padding_tokens=mask_padding_tokens,
 residual_checkpoint_name="cross_attn",
 enable_jax_named_scopes=enable_jax_named_scopes,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
 )
 assert cross_attn_norm is True
 self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -486,6 +492,8 @@ def __init__(
 mask_padding_tokens: bool = True,
 scan_layers: bool = True,
 enable_jax_named_scopes: bool = False,
+use_base2_exp: bool = False,
+use_experimental_scheduler: bool = False,
 ):
 inner_dim = num_attention_heads * attention_head_dim
 out_channels = out_channels or in_channels
@@ -547,6 +555,8 @@ def init_block(rngs):
 enable_jax_named_scopes=enable_jax_named_scopes,
 added_kv_proj_dim=added_kv_proj_dim,
 image_seq_len=image_seq_len,
+use_base2_exp=use_base2_exp,
+use_experimental_scheduler=use_experimental_scheduler,
 )
 
 self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
 wan_config["mask_padding_tokens"] = config.mask_padding_tokens
 wan_config["scan_layers"] = config.scan_layers
 wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes
+wan_config["use_base2_exp"] = config.use_base2_exp
+wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler
 
 # 2. eval_shape - will not use flops or create weights on device
 # thus not using HBM memory.
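
On the pipeline side the two flags follow the same path as mask_padding_tokens and scan_layers: they are copied from the run config into the wan_config kwargs dict before the model is instantiated. A toy illustration (SimpleNamespace stands in for the real config object, which this snippet does not reproduce):

```python
# Hypothetical stand-in for the parsed run config; the real object comes from
# MaxDiffusion's config loading, not from SimpleNamespace.
from types import SimpleNamespace

config = SimpleNamespace(use_base2_exp=False, use_experimental_scheduler=True)

wan_config = {}
wan_config["use_base2_exp"] = config.use_base2_exp
wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler

# wan_config is later passed as keyword arguments to the transformer,
# whose __init__ now accepts both flags (see transformer_wan.py above).
print(wan_config)
```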
