
Commit d9f784c

Merge branch 'main' into revert-for-preview

2 parents: 014909e + 503e9d6

5 files changed: 25 additions & 55 deletions


docs/attention_blocks_flowchart.md

Lines changed: 0 additions & 30 deletions
This file was deleted.
-229 KB
Binary file not shown.

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 0 deletions
@@ -501,6 +501,7 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
+    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
     flash_block_sizes = splash_attention_kernel.BlockSizes(
         block_q=config.flash_block_sizes["block_q"],
         block_kv_compute=config.flash_block_sizes["block_kv_compute"],
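
For orientation, here is a hedged sketch of how the new flag plausibly threads through the rest of the constructor, which the hunk above truncates. The remaining field names follow the splash attention kernel's BlockSizes; reading the dq blocks from config and nulling them under the fused kernel is an assumption, not a quote from the file.

# Sketch only: get_flash_block_sizes as it plausibly reads after this commit.
# Fields beyond block_kv_compute are assumed to mirror the visible ones.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel


def get_flash_block_sizes(config):
  """Create custom flash attention BlockSizes."""
  flash_block_sizes = None
  if len(config.flash_block_sizes.keys()) > 0:
    # New in this commit: optional flag, defaulting to False so existing
    # configs keep the unfused backward kernel.
    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
    flash_block_sizes = splash_attention_kernel.BlockSizes(
        block_q=config.flash_block_sizes["block_q"],
        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
        block_kv=config.flash_block_sizes["block_kv"],
        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
        # Assumed: the fused backward kernel produces dQ inside the dKV pass,
        # so the standalone dq block sizes are dropped when the flag is set.
        block_q_dq=None if use_fused_bwd_kernel else config.flash_block_sizes["block_q_dq"],
        block_kv_dq=None if use_fused_bwd_kernel else config.flash_block_sizes["block_kv_dq"],
        use_fused_bwd_kernel=use_fused_bwd_kernel,
    )
  return flash_block_sizes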

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 5 deletions
@@ -189,16 +189,16 @@ def _tpu_flash_attention(
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
+        block_q=min(q_max_block_size, query.shape[2]),
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=min(q_max_block_size, query.shape[2]),
-        block_kv_dq=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
+        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
   num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
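
The tokamax_flash branch above opts into the splash kernel's fused backward pass. As far as the upstream kernel goes, use_fused_bwd_kernel=True is expected to pair with block_q_dq and block_kv_dq left as None, since the fused dKV pass also emits dQ. A small illustration with made-up sequence lengths and block caps (nothing here is quoted from the diff):

# Illustration only, with hypothetical sizes: a fused-backward BlockSizes.
# The standalone dq block sizes stay None when the backward kernel is fused.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

seq_len = 4096                                   # hypothetical q/kv length
q_max_block_size, kv_max_block_size = 512, 1024  # hypothetical caps

fused_block_sizes = splash_attention_kernel.BlockSizes(
    block_q=min(q_max_block_size, seq_len),
    block_kv_compute=min(kv_max_block_size, seq_len),
    block_kv=min(kv_max_block_size, seq_len),
    block_q_dkv=min(q_max_block_size, seq_len),
    block_kv_dkv=min(kv_max_block_size, seq_len),
    block_kv_dkv_compute=min(kv_max_block_size, seq_len),
    block_q_dq=None,   # left unset alongside the fused kernel
    block_kv_dq=None,  # left unset alongside the fused kernel
    use_fused_bwd_kernel=True,
)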

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 19 additions & 20 deletions
@@ -181,25 +181,24 @@ def test_wan_block(self):
     assert dummy_output.shape == dummy_hidden_states.shape

   def test_wan_attention(self):
-    for attention_kernel in ["flash", "tokamax_flash"]:
-      pyconfig.initialize(
-          [
-              None,
-              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-              f"attention={attention_kernel}"
-          ],
-          unittest=True
-      )
-      config = pyconfig.config
-      batch_size = 1
-      channels = 16
-      frames = 21
-      height = 90
-      width = 160
-      hidden_states_shape = (batch_size, frames, height, width, channels)
-      dummy_hidden_states = jnp.ones(hidden_states_shape)
-      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)

     key = jax.random.key(0)
     rngs = nnx.Rngs(key)

@@ -425,4 +424,4 @@ def test_quantize_transformer_disabled(self, mock_quantize_model):


 if __name__ == "__main__":
-  absltest.main()
+  absltest.main()
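
With the loop over ["flash", "tokamax_flash"] removed, the test now exercises whatever attention kernel base_wan_14b.yml configures. Patterned on the removed lines, a hypothetical one-off override would look like the snippet below; THIS_DIR and pyconfig are the test module's own names.

# Hypothetical usage, modeled on the removed loop body: force one kernel
# via the argv-style overrides that pyconfig.initialize parses.
import os

pyconfig.initialize(
    [
        None,
        os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
        "attention=flash",  # the loop formerly swept flash and tokamax_flash
    ],
    unittest=True,
)
config = pyconfig.config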
