
Commit 69d2a30

Merge main
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent 1d21a53 commit 69d2a30

2 files changed: 4 additions & 3 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 1 deletion

@@ -302,7 +302,7 @@ def wrap_flash_attention(query, key, value):
      splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
          mask=mask,
          q_seq_shards=1,  # the size of the axis sharded over seq_len
-         config=convert_to_tokamax_splash_config(block_sizes),
+         config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
          save_residuals=True if attention_kernel == "ring" else False,
      )
  else:
@@ -312,6 +312,7 @@ def wrap_flash_attention(query, key, value):
          q_seq_shards=1,  # the size of the axis sharded over seq_len
          block_sizes=block_sizes,
          save_residuals=True if attention_kernel == "ring" else False,
+         residual_checkpoint_name=residual_checkpoint_name
      )
      vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
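
For context: residual_checkpoint_name feeds JAX's named-checkpoint machinery, which lets a rematerialization policy save specific intermediates by name instead of recomputing them in the backward pass. Below is a minimal, hypothetical sketch of the idea; toy_splash_attention is invented for illustration and is not the tokamax/splash kernel, whose internals this diff does not show.

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def toy_splash_attention(q, k, v, residual_checkpoint_name):
        # Stand-in for a flash-attention kernel that saves residuals
        # (e.g. attention probabilities) for the backward pass.
        scores = q @ k.T / jnp.sqrt(q.shape[-1])
        probs = jax.nn.softmax(scores, axis=-1)
        # Tag the residual so a remat policy can refer to it by name.
        probs = checkpoint_name(probs, residual_checkpoint_name)
        return probs @ v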

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions

@@ -285,7 +285,7 @@ def __init__(
          attention_kernel=attention,
          dropout=dropout,
          is_self_attention=True,
-         mask_padding_tokens=mask_padding_tokens
+         mask_padding_tokens=mask_padding_tokens,
          residual_checkpoint_name="self_attn",
      )

@@ -306,7 +306,7 @@ def __init__(
          attention_kernel=attention,
          dropout=dropout,
          is_self_attention=False,
-         mask_padding_tokens=mask_padding_tokens
+         mask_padding_tokens=mask_padding_tokens,
          residual_checkpoint_name="cross_attn",
      )
      assert cross_attn_norm is True
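
The "self_attn" and "cross_attn" names passed above pay off when a remat policy selects residuals by name. Here is a usage sketch under the assumption that each layer tags its residuals with these names; block and its toy attention math are invented for illustration.

    import jax
    import jax.numpy as jnp
    from functools import partial
    from jax.ad_checkpoint import checkpoint_name

    # Save only tensors tagged "self_attn" / "cross_attn"; recompute the rest.
    policy = jax.checkpoint_policies.save_only_these_names("self_attn", "cross_attn")

    @partial(jax.checkpoint, policy=policy)
    def block(x):
        h = checkpoint_name(jnp.tanh(x @ x.T) @ x, "self_attn")   # toy self-attention
        h = checkpoint_name(jnp.tanh(h @ h.T) @ h, "cross_attn")  # toy cross-attention
        return x + h

    grads = jax.grad(lambda x: block(x).sum())(jnp.ones((4, 4)))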
