Skip to content

Commit 0ec4b02

Browse files
Residual block test
1 parent aeabe27 commit 0ec4b02

2 files changed

Lines changed: 173 additions & 37 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 113 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,49 @@ def __init__(
313313
dropout: float = 0.0,
314314
non_linearity: str = "silu",
315315
):
316-
pass
316+
self.nonlinearity = get_activation(non_linearity)
317+
318+
# layers
319+
self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False)
320+
self.conv1 = WanCausalConv3d(
321+
rngs=rngs,
322+
in_channels=in_dim,
323+
out_channels=out_dim,
324+
kernel_size=3,
325+
padding=1
326+
)
327+
self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
328+
self.dropout = nnx.Dropout(dropout, rngs=rngs)
329+
self.conv2 = WanCausalConv3d(
330+
rngs=rngs,
331+
in_channels=out_dim,
332+
out_channels=out_dim,
333+
kernel_size=3,
334+
padding=1
335+
)
336+
self.conv_shortcut = WanCausalConv3d(
337+
rngs=rngs,
338+
in_channels=in_dim,
339+
out_channels=out_dim,
340+
kernel_size=1
341+
) if in_dim != out_dim else Identity()
342+
317343

318344
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
319-
return x
345+
# Apply shortcut connection
346+
#breakpoint()
347+
h = self.conv_shortcut(x)
348+
349+
x = self.norm1(x)
350+
x = self.nonlinearity(x)
351+
x = self.conv1(x)
352+
353+
x = self.norm2(x)
354+
x = self.nonlinearity(x)
355+
x = self.dropout(x)
356+
x = self.conv2(x)
357+
358+
return x + h
320359

321360
class WanAttentionBlock(nnx.Module):
322361
def __init__(
@@ -397,11 +436,11 @@ def __init__(
397436

398437
# init block
399438
self.conv_in = WanCausalConv3d(
439+
rngs=rngs,
400440
in_channels=3,
401441
out_channels=dims[0],
402442
kernel_size=3,
403443
padding=1,
404-
rngs=rngs
405444
)
406445

407446
# downsample blocks
@@ -439,6 +478,12 @@ def __init__(
439478
)
440479

441480
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
481+
# (1, 1, 480, 720, 3)
482+
x = self.conv_in(x)
483+
# (1, 1, 480, 720, 96)
484+
for layer in self.down_blocks:
485+
x = layer(x)
486+
breakpoint()
442487
return x
443488

444489
class WanDecoder3d(nnx.Module):
@@ -480,7 +525,13 @@ def __init__(
480525
scale = 1.0 / 2 ** (len(dim_mult) - 2)
481526

482527
# init block
483-
self.conv_in = WanCausalConv3d(in_channels=z_dim, out_channels=dims[0], kernel_size=3, padding=1, rngs=rngs)
528+
self.conv_in = WanCausalConv3d(
529+
rngs=rngs,
530+
in_channels=z_dim,
531+
out_channels=dims[0],
532+
kernel_size=3,
533+
padding=1
534+
)
484535

485536
# middle_blocks
486537
self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)
@@ -516,7 +567,13 @@ def __init__(
516567
# output blocks
517568
self.norm_out = nnx.RMSNorm(num_features=out_dim, )
518569
self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs)
519-
self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs)
570+
self.conv_out = WanCausalConv3d(
571+
rngs=rngs,
572+
in_channels=out_dim,
573+
out_channels=3,
574+
kernel_size=3,
575+
padding=1
576+
)
520577

521578
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
522579
breakpoint()
@@ -533,7 +590,7 @@ def __init__(
533590
dim_mult: Tuple[int] = [1,2,4,4],
534591
num_res_blocks: int = 2,
535592
attn_scales: List[float] = [],
536-
temporal_downsample: List[bool] = [False, True, True],
593+
temperal_downsample: List[bool] = [False, True, True],
537594
dropout: float = 0.0,
538595
latents_mean: List[float] = [
539596
-0.7571,-0.7089,-0.9113,0.1075,-0.1745,0.9653,-0.1517, 1.5508,
@@ -545,31 +602,59 @@ def __init__(
545602
],
546603
):
547604
self.z_dim = z_dim
548-
self.temporal_downsample = temporal_downsample
549-
self.temporal_upsample = temporal_downsample[::-1]
550-
551-
self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1)
552-
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs)
553-
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs)
605+
self.temperal_downsample = temperal_downsample
606+
self.temporal_upsample = temperal_downsample[::-1]
554607

555-
self.decoder = WanDecoder3d(
556-
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temporal_upsample, dropout
608+
self.encoder = WanEncoder3d(
609+
rngs=rngs,
610+
dim=base_dim,
611+
z_dim=z_dim * 2,
612+
dim_mult=dim_mult,
613+
num_res_blocks=num_res_blocks,
614+
attn_scales=attn_scales,
615+
temperal_downsample=temperal_downsample,
616+
dropout=dropout,
617+
)
618+
self.quant_conv = WanCausalConv3d(
619+
rngs=rngs,
620+
in_channels=z_dim * 2,
621+
out_channels=z_dim * 2,
622+
kernel_size=1
623+
)
624+
self.post_quant_conv = WanCausalConv3d(
625+
rngs=rngs,
626+
in_channels=z_dim,
627+
out_channels=z_dim,
628+
kernel_size=1,
557629
)
630+
631+
# self.decoder = WanDecoder3d(
632+
# rngs=rngs,
633+
# dim=base_dim,
634+
# z_dim=z_dim,
635+
# dim_mult=dim_mult,
636+
# num_res_blocks=num_res_blocks,
637+
# attn_scales=attn_scales,
638+
# temperal_upsample=self.temporal_upsample,
639+
# dropout=dropout
640+
# )
558641
self.clear_cache()
559642

560643
def clear_cache(self):
561644
""" Resets cache dictionaries and indices"""
562645
def _count_conv3d(module):
563646
count = 0
564-
node_types = nnx.graph.iter_graph(module, nnx.Module)
565-
for node in node_types:
566-
if isinstance(node.value, WanCausalConv3d):
647+
node_types = nnx.graph.iter_graph([module])
648+
for path, value in node_types:
649+
#breakpoint()
650+
if isinstance(value, WanCausalConv3d):
651+
print("value: ", value)
567652
count +=1
568653
return count
569654

570-
self._conv_num = _count_conv3d(self.decoder)
571-
self._conv_idx = [0]
572-
self._feat_map = [None] * self._conv_num
655+
# self._conv_num = _count_conv3d(self.decoder)
656+
# self._conv_idx = [0]
657+
# self._feat_map = [None] * self._conv_num
573658
# cache encode
574659
self._enc_conv_num = _count_conv3d(self.encoder)
575660
self._enc_conv_idx = [0]
@@ -581,7 +666,7 @@ def _encode(self, x: jax.Array):
581666
x = jnp.transpose(x, (0, 2, 3, 4, 1))
582667
assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
583668

584-
self.clear_cache()
669+
#self.clear_cache()
585670

586671
t = x.shape[1]
587672
iter_ = 1 + (t - 1) // 4
@@ -590,7 +675,7 @@ def _encode(self, x: jax.Array):
590675
out = self.encoder(
591676
x[:, :1, :, :, :],
592677
feat_cache=self._enc_feat_map,
593-
feat_ids=self._enc_conv_idx
678+
feat_idx=self._enc_conv_idx
594679
)
595680
else:
596681
out_ = self.encoder(
@@ -600,11 +685,12 @@ def _encode(self, x: jax.Array):
600685
)
601686
out = jnp.concatenate([out, out_], axis=1)
602687

603-
enc = self.quant_conv(out)
604-
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
605-
enc = jnp.concatenate([mu, logvar], dim=1)
606-
self.clear_cache()
607-
return enc
688+
# enc = self.quant_conv(out)
689+
# mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
690+
# enc = jnp.concatenate([mu, logvar], dim=1)
691+
# self.clear_cache()
692+
# return enc
693+
return x
608694

609695
def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
610696
""" Encode video into latent distribution."""

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
WanUpsample,
3131
AutoencoderKLWan,
3232
WanEncoder3d,
33+
WanResidualBlock,
3334
WanRMS_norm,
3435
WanResample,
3536
ZeroPaddedConv2D
@@ -318,6 +319,45 @@ def test_3d_conv(self):
318319
output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
319320
assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
320321

322+
def test_wan_residual(self):
323+
key = jax.random.key(0)
324+
rngs = nnx.Rngs(key)
325+
# one test
326+
in_dim = out_dim = 96
327+
batch = 1
328+
t = 1
329+
height = 480
330+
width = 720
331+
dim = 96
332+
input_shape = (batch, t, height, width, dim)
333+
expected_output_shape = (batch, t, height, width, dim)
334+
335+
wan_residual_block = WanResidualBlock(
336+
in_dim=in_dim,
337+
out_dim=out_dim,
338+
rngs=rngs,
339+
)
340+
dummy_input = jnp.ones(input_shape)
341+
dummy_output = wan_residual_block(dummy_input)
342+
assert dummy_output.shape == expected_output_shape
343+
344+
# another test
345+
in_dim = 96
346+
out_dim = 196
347+
expected_output_shape = (batch, t, height, width, out_dim)
348+
349+
wan_residual_block = WanResidualBlock(
350+
in_dim=in_dim,
351+
out_dim=out_dim,
352+
rngs=rngs,
353+
)
354+
dummy_input = jnp.ones(input_shape)
355+
dummy_output = wan_residual_block(dummy_input)
356+
assert dummy_output.shape == expected_output_shape
357+
358+
359+
360+
321361
def test_wan_encode(self):
322362
key = jax.random.key(0)
323363
rngs = nnx.Rngs(key)
@@ -328,24 +368,34 @@ def test_wan_encode(self):
328368
attn_scales = []
329369
temperal_downsample = [False, True, True]
330370
nonlinearity = "silu"
331-
wan_encoder = WanEncoder3d(
332-
rngs=rngs,
333-
dim=dim,
334-
z_dim=z_dim,
335-
dim_mult=dim_mult,
336-
num_res_blocks=num_res_blocks,
337-
attn_scales=attn_scales,
338-
temperal_downsample=temperal_downsample,
339-
non_linearity=nonlinearity
371+
wan_vae = AutoencoderKLWan(
372+
rngs=rngs,
373+
base_dim=dim,
374+
z_dim=z_dim,
375+
dim_mult=dim_mult,
376+
num_res_blocks=num_res_blocks,
377+
attn_scales=attn_scales,
378+
temperal_downsample=temperal_downsample,
340379
)
380+
# wan_encoder = WanEncoder3d(
381+
# rngs=rngs,
382+
# dim=dim,
383+
# z_dim=z_dim,
384+
# dim_mult=dim_mult,
385+
# num_res_blocks=num_res_blocks,
386+
# attn_scales=attn_scales,
387+
# temperal_downsample=temperal_downsample,
388+
# non_linearity=nonlinearity
389+
# )
341390
batch = 1
342391
channels = 3
343392
t = 49
344393
height = 480
345394
width = 720
346395
input_shape = (batch, channels, t, height, width)
347396
input = jnp.ones(input_shape)
348-
output = wan_encoder(input)
397+
output = wan_vae.encode(input)
398+
breakpoint()
349399

350400

351401

0 commit comments

Comments
 (0)