
Commit fc77dc0

fix some tests.
1 parent 5494644 commit fc77dc0

3 files changed

Lines changed: 31 additions & 17 deletions


src/maxdiffusion/tests/attention_test.py

Lines changed: 14 additions & 5 deletions
@@ -37,7 +37,10 @@ def setUp(self):
   def test_splash_attention(self):
     """Test numerics of splash attention are equivalent to dot_product"""

-    pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml")], unittest=True)
+    pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml"),
+                         'flash_block_sizes={"block_q" : 512, "block_kv_compute": 512, "block_kv": 512,'
+                         '"block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,'
+                         '"block_q_dq": 512, "block_kv_dq": 512}',], unittest=True)
     config = pyconfig.config

     batch = 8
@@ -47,15 +50,14 @@ def test_splash_attention(self):

     key1, key2 = jax.random.split(jax.random.PRNGKey(0))
     x = jax.random.normal(key1, (batch, length, heads * head_depth))
-
     dot_product_attention = FlaxAttention(
         heads * head_depth,
         heads,
         head_depth,
         split_head_dim=True,
         attention_kernel="dot_product",
         mesh=None,
-        dtype=jnp.bfloat16,
+        dtype=jnp.bfloat16
     )

     params = dot_product_attention.init(key2, x)["params"]
@@ -64,9 +66,16 @@ def test_splash_attention(self):

     devices_array = max_utils.create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
-
+    flash_block_sizes = max_utils.get_flash_block_sizes(config)
     splash_attention = FlaxAttention(
-        heads * head_depth, heads, head_depth, split_head_dim=True, attention_kernel="flash", mesh=mesh, dtype=jnp.bfloat16
+        heads * head_depth,
+        heads,
+        head_depth,
+        split_head_dim=True,
+        attention_kernel="flash",
+        mesh=mesh,
+        dtype=jnp.bfloat16,
+        flash_block_sizes=flash_block_sizes
     )

     params = splash_attention.init(key2, x)["params"]
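
Note: the test now passes `flash_block_sizes` as a config override and reads it back through `max_utils.get_flash_block_sizes(config)`. As a rough sketch of the plumbing involved (the actual helper body in maxdiffusion may differ), such a helper would translate the parsed config dict into the Pallas splash-attention `BlockSizes` container that the flash kernel expects:

# Hedged sketch, not copied from the repo: build a splash-attention
# BlockSizes from the parsed flash_block_sizes config dict. The import
# path is JAX's Pallas TPU splash-attention kernel.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

def get_flash_block_sizes(config):
  flash_block_sizes = None
  if len(config.flash_block_sizes) > 0:
    flash_block_sizes = splash_attention_kernel.BlockSizes(
        block_q=config.flash_block_sizes["block_q"],
        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
        block_kv=config.flash_block_sizes["block_kv"],
        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
        block_q_dq=config.flash_block_sizes["block_q_dq"],
        block_kv_dq=config.flash_block_sizes["block_kv_dq"],
    )
  return flash_block_sizes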

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 15 additions & 11 deletions
@@ -163,8 +163,8 @@ def test_wan_block(self):
         mesh=mesh,
         flash_block_sizes=flash_block_sizes,
     )
-
-    dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
+    with mesh:
+      dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape

   def test_wan_attention(self):
@@ -210,10 +210,10 @@ def test_wan_attention(self):

     dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
     dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-
-    dummy_output = attention(
-        hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-    )
+    with mesh:
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
     assert dummy_output.shape == dummy_hidden_states_shape

     # dot product
@@ -246,7 +246,7 @@ def test_wan_model(self):
     frames = 21
     height = 90
     width = 160
-    hidden_states_shape = (batch_size, frames, height, width, channels)
+    hidden_states_shape = (batch_size, channels, frames, height, width)
     dummy_hidden_states = jnp.ones(hidden_states_shape)

     key = jax.random.key(0)
@@ -266,10 +266,14 @@ def test_wan_model(self):

     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
-
-    dummy_output = wan_model(
-        hidden_states=dummy_hidden_states, timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states
-    )
+    with mesh:
+      dummy_output = wan_model(
+          hidden_states=dummy_hidden_states,
+          timestep=dummy_timestep,
+          encoder_hidden_states=dummy_encoder_hidden_states,
+          is_uncond=jnp.array(True, dtype=jnp.bool_),
+          slg_mask=jnp.zeros(40, dtype=jnp.bool_)
+      )
     assert dummy_output.shape == hidden_states_shape
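
Note: all three call sites are now wrapped in `with mesh:`. A minimal standalone sketch of why, assuming the modules apply sharding constraints internally: a `jax.sharding.Mesh` is a context manager, and sharding annotations expressed as bare `PartitionSpec`s resolve against the active mesh, so the calls fail without one.

# Minimal sketch, not from this repo; mesh shape and axis name are
# illustrative. Entering the Mesh context makes it the ambient mesh for
# sharding annotations issued during the call.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))  # 1-D device mesh

x = jnp.ones((8, 16))
with mesh:
  # A bare PartitionSpec needs an ambient mesh to resolve "data"; this
  # is the same reason the module calls above now run under `with mesh:`.
  y = jax.lax.with_sharding_constraint(x, PartitionSpec("data", None))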

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@
 import jax
 import jax.tree_util as jtu
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from ..schedulers import FlaxEulerDiscreteScheduler
 from .. import max_utils, max_logging, train_utils, maxdiffusion_utils
 from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
@@ -115,7 +116,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data):
     for step in np.arange(start_step, self.config.max_train_steps):
       if self.config.enable_profiler and step == first_profiling_step:
         max_utils.activate_profiler(self.config)
-      with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh:
+      with jax.profiler.StepTraceAnnotation("train", step_num=step), pipeline.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
        state, train_metric, rng = p_train_step(state, graphdef, data, rng)

       new_time = datetime.datetime.now()
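
Note: the train step now also enters `nn_partitioning.axis_rules(...)`, so Flax logical axis names used inside the model resolve to mesh axes during `p_train_step`. A small sketch of the mechanism, with hypothetical rules standing in for `config.logical_axis_rules`:

# Hedged sketch: axis_rules maps logical axis names to mesh axis names.
# The ("batch", "data") / ("embed", None) rules are hypothetical; the
# trainer takes its rules from config.logical_axis_rules.
from flax.linen import partitioning as nn_partitioning

rules = (("batch", "data"), ("embed", None))

with nn_partitioning.axis_rules(rules):
  # Inside the context, logical names translate to mesh axes:
  # "batch" -> "data" (sharded), "embed" -> None (replicated).
  spec = nn_partitioning.logical_to_mesh_axes(("batch", "embed"))
  print(spec)  # PartitionSpec('data', None)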
