
Commit e167476

Merge branch 'main' into revert-for-preview
2 parents 014909e + 503e9d6

5 files changed: 63 additions & 94 deletions


docs/attention_blocks_flowchart.md

Lines changed: 0 additions & 30 deletions
This file was deleted.
-229 KB
Binary file not shown.

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 0 deletions
@@ -501,6 +501,7 @@ def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
+    use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False)
     flash_block_sizes = splash_attention_kernel.BlockSizes(
         block_q=config.flash_block_sizes["block_q"],
         block_kv_compute=config.flash_block_sizes["block_kv_compute"],
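
Note: the added line reads an optional "use_fused_bwd_kernel" key out of the config's flash_block_sizes mapping, falling back to False when the key is absent. A minimal sketch of that lookup pattern, with an illustrative config dict (the key names come from the hunk; the values are placeholders):

    # Hypothetical flash_block_sizes config entry; values are illustrative only.
    flash_block_sizes_cfg = {
        "block_q": 512,
        "block_kv_compute": 512,
        "use_fused_bwd_kernel": True,
    }
    # Same lookup as the added line: a missing key falls back to False.
    use_fused_bwd_kernel = flash_block_sizes_cfg.get("use_fused_bwd_kernel", False)
    assert use_fused_bwd_kernel is True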

src/maxdiffusion/models/attention_flax.py

Lines changed: 5 additions & 5 deletions
@@ -189,16 +189,16 @@ def _tpu_flash_attention(
   if flash_block_sizes:
     block_sizes = flash_block_sizes
   else:
-    block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size
     block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=block_size_q,
+        block_q=min(q_max_block_size, query.shape[2]),
         block_kv_compute=min(kv_max_block_size, key.shape[2]),
         block_kv=min(kv_max_block_size, key.shape[2]),
-        block_q_dkv=block_size_q,
+        block_q_dkv=min(q_max_block_size, query.shape[2]),
         block_kv_dkv=min(kv_max_block_size, key.shape[2]),
         block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]),
-        block_q_dq=min(q_max_block_size, query.shape[2]),
-        block_kv_dq=min(kv_max_block_size, query.shape[2]),
+        block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq,
+        block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]),
+        use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
     )
   num_fsdp_shards = mesh.shape["fsdp"]
   query = _reshape_data_for_flash(query, heads)
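
Note: when no custom sizes are supplied, each default block dimension is now clamped to the actual sequence length with min(max_block_size, seq_len), and the "tokamax_flash" kernel switches to the fused backward pass, leaving the dq block sizes as None. A minimal sketch of the clamping idea (the function name is illustrative, not the module's API):

    # Clamp a candidate block size so it never exceeds the sequence length.
    def clamp_block(max_block_size: int, seq_len: int) -> int:
      return min(max_block_size, seq_len)

    assert clamp_block(512, 75600) == 512  # long sequence: keep the maximum block
    assert clamp_block(512, 128) == 128    # short sequence: shrink the block to fit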

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 57 additions & 59 deletions
@@ -22,6 +22,7 @@
 from unittest.mock import Mock, patch, call
 from absl.testing import absltest
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh

 from .. import pyconfig
@@ -163,43 +164,41 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))

     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape

   def test_wan_attention(self):
-    for attention_kernel in ["flash", "tokamax_flash"]:
-      pyconfig.initialize(
-          [
-              None,
-              os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
-              f"attention={attention_kernel}"
-          ],
-          unittest=True
-      )
-      config = pyconfig.config
-      batch_size = 1
-      channels = 16
-      frames = 21
-      height = 90
-      width = 160
-      hidden_states_shape = (batch_size, frames, height, width, channels)
-      dummy_hidden_states = jnp.ones(hidden_states_shape)
-      wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
-      dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+
+    batch_size = 1
+    channels = 16
+    frames = 21
+    height = 90
+    width = 160
+    hidden_states_shape = (batch_size, frames, height, width, channels)
+    dummy_hidden_states = jnp.ones(hidden_states_shape)
+    wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024)
+    dummy_rotary_emb = wan_rot_embed(dummy_hidden_states)

     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
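
Note: the rewritten test drops the loop over ["flash", "tokamax_flash"] and initializes the config once from the YAML file alone. pyconfig.initialize takes an argv-style list, so a kernel override could still be passed as a trailing "key=value" string, as the deleted line did. A hedged sketch of that call shape, reusing the paths from the test:

    # Argv-style initialization with an explicit attention override; the
    # "attention=flash" entry mirrors the deleted f"attention={attention_kernel}".
    pyconfig.initialize(
        [None, os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), "attention=flash"],
        unittest=True,
    )
    config = pyconfig.config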
@@ -210,40 +209,39 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-    assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass

   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
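
Note: both tests now build their modules and run the forward pass inside the mesh and nn_partitioning.axis_rules contexts, since flax.linen partitioning can only resolve logical axis names to mesh axes while an axis_rules scope is active. A self-contained sketch of the pattern (the mesh shape and rules below are placeholders, not the repo's config values):

    import jax
    import numpy as np
    from jax.sharding import Mesh
    from flax.linen import partitioning as nn_partitioning

    # One-axis mesh over all local devices; the real tests use config.mesh_axes.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, ("data",))
    # Placeholder logical-to-physical rules; the real tests use config.logical_axis_rules.
    rules = (("batch", "data"), ("embed", None))
    with mesh, nn_partitioning.axis_rules(rules):
      pass  # construct modules and call them here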
