Skip to content

Commit 4325325

Browse files
add wan mid block vae test
1 parent 9b42117 commit 4325325

2 files changed

Lines changed: 31 additions & 10 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,8 +414,20 @@ def __init__(
414414
num_layers: int = 1
415415
):
416416
self.dim = dim
417+
resnets = [WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity)]
418+
attentions = []
419+
for _ in range(num_layers):
420+
attentions.append(WanAttentionBlock(dim=dim, rngs=rngs))
421+
resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity))
422+
self.attentions = attentions
423+
self.resnets = resnets
417424

418425
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
    """Apply the first resnet, then each (attention, resnet) pair in order.

    Args:
        x: Input array handed to the first entry of ``self.resnets``.
        feat_cache: Not read by this method.
        feat_idx: Not read by this method.

    Returns:
        The value produced by the last resnet that ran.
    """
    hidden = self.resnets[0](x)
    # The remaining resnets are paired one-to-one with the attention entries;
    # an attention entry of None is skipped and only its resnet is applied.
    for attention, residual in zip(self.attentions, self.resnets[1:]):
        attended = hidden if attention is None else attention(hidden)
        hidden = residual(attended)
    return hidden
420432

421433
class WanUpBlock(nnx.Module):
@@ -519,6 +531,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
519531
# (1, 1, 480, 720, 96)
520532
for layer in self.down_blocks:
521533
x = layer(x)
534+
535+
x = self.mid_block(x)
522536
breakpoint()
523537
return x
524538

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
WanUpsample,
3131
AutoencoderKLWan,
3232
WanEncoder3d,
33+
WanMidBlock,
3334
WanResidualBlock,
3435
WanRMS_norm,
3536
WanResample,
@@ -373,6 +374,22 @@ def test_wan_attention(self):
373374
output = wan_attention(dummy_input)
374375
assert output.shape == input_shape
375376

377+
def test_wan_midblock(self):
    """Check that WanMidBlock returns an output shaped like its input."""
    num_channels = 384
    # Shape tuple is ordered (batch, t, height, width, dim).
    expected_shape = (1, 1, 60, 90, num_channels)
    mid_block = WanMidBlock(dim=num_channels, rngs=nnx.Rngs(jax.random.key(0)))
    result = mid_block(jnp.ones(expected_shape))
    assert result.shape == expected_shape
392+
376393
def test_wan_encode(self):
377394
key = jax.random.key(0)
378395
rngs = nnx.Rngs(key)
@@ -392,16 +409,6 @@ def test_wan_encode(self):
392409
attn_scales=attn_scales,
393410
temperal_downsample=temperal_downsample,
394411
)
395-
# wan_encoder = WanEncoder3d(
396-
# rngs=rngs,
397-
# dim=dim,
398-
# z_dim=z_dim,
399-
# dim_mult=dim_mult,
400-
# num_res_blocks=num_res_blocks,
401-
# attn_scales=attn_scales,
402-
# temperal_downsample=temperal_downsample,
403-
# non_linearity=nonlinearity
404-
# )
405412
batch = 1
406413
channels = 3
407414
t = 49

0 commit comments

Comments
 (0)