Commit 9b42117

add wan vae attention test
1 parent 0ec4b02 commit 9b42117

2 files changed: 57 additions & 6 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 37 additions & 1 deletion
@@ -364,9 +364,45 @@ def __init__(
       rngs: nnx.Rngs
   ):
     self.dim = dim
+    self.norm = WanRMS_norm(rngs=rngs, dim=dim, channel_first=False)
+    self.to_qkv = nnx.Conv(
+        in_features=dim,
+        out_features=dim * 3,
+        kernel_size=1,
+        rngs=rngs
+    )
+    self.proj = nnx.Conv(
+        in_features=dim,
+        out_features=dim,
+        kernel_size=1,
+        rngs=rngs
+    )
 
   def __call__(self, x: jax.Array):
-    return x
+    batch_size, time, height, width, channels = x.shape
+    identity = x
+
+    # Fold time into the batch so attention is applied per frame.
+    x = x.reshape(batch_size * time, height, width, channels)
+    x = self.norm(x)
+
+    qkv = self.to_qkv(x)  # Output: (N*D, H, W, C * 3)
+
+    # jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim),
+    # so flatten the spatial grid into H*W tokens with a single head.
+    qkv = qkv.reshape(batch_size * time, height * width, 1, channels * 3)
+    q, k, v = jnp.split(qkv, 3, axis=-1)
+
+    x = jax.nn.dot_product_attention(q, k, v)
+    x = x.reshape(batch_size * time, height, width, channels)
+
+    # output projection
+    x = self.proj(x)
+
+    # Reshape back to (batch, time, height, width, channels)
+    x = x.reshape(batch_size, time, height, width, channels)
+
+    return x + identity
+
 
 class WanMidBlock(nnx.Module):
   def __init__(
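For reference, a minimal sketch (not part of the commit) of the layout convention the new block relies on: jax.nn.dot_product_attention takes query, key and value as (batch, seq_len, num_heads, head_dim) arrays, so a single head over the H*W spatial tokens is plain softmax(q k^T / sqrt(d)) v attention. The sizes below are illustrative, not taken from the model config.

import jax
import jax.numpy as jnp

batch, seq_len, num_heads, head_dim = 2, 16, 1, 8  # illustrative sizes
kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (batch, seq_len, num_heads, head_dim))
k = jax.random.normal(kk, (batch, seq_len, num_heads, head_dim))
v = jax.random.normal(kv, (batch, seq_len, num_heads, head_dim))

out = jax.nn.dot_product_attention(q, k, v)

# The same computation written out explicitly for the single head.
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)
ref = jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)
print(jnp.max(jnp.abs(out - ref)))  # should be ~0 up to numerical tolerance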

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 20 additions & 5 deletions
@@ -33,7 +33,8 @@
     WanResidualBlock,
     WanRMS_norm,
     WanResample,
-    ZeroPaddedConv2D
+    ZeroPaddedConv2D,
+    WanAttentionBlock
 )
 
 CACHE_T = 2
@@ -322,7 +323,7 @@ def test_3d_conv(self):
   def test_wan_residual(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
-    # one test
+    # --- Test Case 1: same in/out dim ---
     in_dim = out_dim = 96
     batch = 1
     t = 1
@@ -341,7 +342,7 @@ def test_wan_residual(self):
     dummy_output = wan_residual_block(dummy_input)
     assert dummy_output.shape == expected_output_shape
 
-    # another test
+    # --- Test Case 2: different in/out dim ---
     in_dim = 96
     out_dim = 196
     expected_output_shape = (batch, t, height, width, out_dim)
@@ -355,8 +356,22 @@ def test_wan_residual(self):
     dummy_output = wan_residual_block(dummy_input)
     assert dummy_output.shape == expected_output_shape
 
-
-
+  def test_wan_attention(self):
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    dim = 384
+    batch = 1
+    t = 1
+    height = 60
+    width = 90
+    input_shape = (batch, t, height, width, dim)
+    wan_attention = WanAttentionBlock(
+        dim=dim,
+        rngs=rngs
+    )
+    dummy_input = jnp.ones(input_shape)
+    output = wan_attention(dummy_input)
+    assert output.shape == input_shape
 
   def test_wan_encode(self):
     key = jax.random.key(0)
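The new test can be run on its own with pytest; the invocation below is illustrative and assumes it is executed from the repository root:

python -m pytest src/maxdiffusion/tests/wan_vae_test.py -k test_wan_attention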
