Skip to content

Commit 8d8d1a0

Browse files
committed
Test on attention type and automatically modify the flash block sizes object when 'tokamax_flash' is requested
Signed-off-by: Kunjan Patel <kunjanp@google.com>
1 parent becfc88 commit 8d8d1a0

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,19 @@ def test_wan_attention(self):
220220
rngs = nnx.Rngs(key)
221221
devices_array = create_device_mesh(config)
222222

223-
224-
mesh = Mesh(devices_array, config.mesh_axes)
223+
mesh_axes = ['data', 'fsdp', 'tensor']
224+
mesh = Mesh(devices_array, mesh_axes)
225225
batch_size = 1
226226
query_dim = 5120
227227
for attention_kernel in ["flash", "tokamax_flash"]:
228+
pyconfig.initialize(
229+
[
230+
None,
231+
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
232+
f"attention={attention_kernel}"
233+
]
234+
)
235+
config = pyconfig.config
228236
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
229237
config.attention = attention_kernel
230238
flash_block_sizes = get_flash_block_sizes(config)

0 commit comments

Comments (0)