Commit cf68754

adds decoder and checks matching resolutions.
1 parent efe8528 commit cf68754

2 files changed: 174 additions & 45 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py
src/maxdiffusion/tests/wan_vae_test.py

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 141 additions & 45 deletions
@@ -22,7 +22,11 @@
 from ...configuration_utils import ConfigMixin, flax_register_to_config
 from ..modeling_flax_utils import FlaxModelMixin
 from ... import common_types
-from ..vae_flax import FlaxAutoencoderKLOutput, FlaxDiagonalGaussianDistribution
+from ..vae_flax import (
+    FlaxAutoencoderKLOutput,
+    FlaxDiagonalGaussianDistribution,
+    FlaxDecoderOutput
+)
 
 BlockSizes = common_types.BlockSizes
 
@@ -82,7 +86,7 @@ def __init__(
         (0, 0)  # Channel dimension - no padding
     )
 
-    # Store the amount of padding needed *before* the depth dimension for caching logoic
+    # Store the amount of padding needed *before* the depth dimension for caching logic
     self._depth_padding_before = self._causal_padding[1][0]  # 2 * padding_tuple[0]
 
     self.conv = nnx.Conv(
@@ -103,7 +107,6 @@ def __call__(self, x: jax.Array, cache_x: Optional[jax.Array] = None) -> jax.Array:
       # Ensure cache has same spatial/channel dims, potentially different depth
       assert cache_x.shape[0] == x.shape[0] and \
           cache_x.shape[2:] == x.shape[2:], "Cache spatial/channel dims mismatch"
-
       cache_len = cache_x.shape[1]
       x = jnp.concatenate([cache_x, x], axis=1)  # Concat along depth (D)
 
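The cache concatenation above has simple shape semantics: the trailing CACHE_T frames of the previous chunk are prepended along the depth axis so the causal convolution sees real history instead of zero padding. A minimal sketch, assuming CACHE_T = 2 as in the reference Wan implementation (shapes here are illustrative):

import jax.numpy as jnp

prev_tail = jnp.zeros((1, 2, 8, 8, 16))          # (N, CACHE_T, H, W, C) kept from the last chunk
chunk = jnp.zeros((1, 1, 8, 8, 16))              # current one-frame chunk
x = jnp.concatenate([prev_tail, chunk], axis=1)  # depth grows from 1 to 3
assert x.shape == (1, 3, 8, 8, 16)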
@@ -166,24 +169,13 @@ def __init__(self, scale_factor: Tuple[float, float], method: str = 'nearest'):
   def __call__(self, x: jax.Array) -> jax.Array:
     input_dtype = x.dtype
     in_shape = x.shape
-    is_3d = len(in_shape) == 5
-    n, d, h, w, c = in_shape if is_3d else (in_shape[0], 1, in_shape[1], in_shape[2], in_shape[3])
-
+    assert len(in_shape) == 4, "This module only accepts rank-4 (N, H, W, C) tensors."
+    n, h, w, c = in_shape
     target_h = int(h * self.scale_factor[0])
     target_w = int(w * self.scale_factor[1])
-
-    # jax.image.resize expects (..., H, W, C)
-    if is_3d:
-      x_reshaped = x.reshape(n * d, h, w, c)
-      out_reshaped = jax.image.resize(x_reshaped.astype(jnp.float32),
-                                      (n * d, target_h, target_w, c),
-                                      method=self.method)
-      out = out_reshaped.reshape(n, d, target_h, target_w, c)
-    else:  # Assuming (N, H, W, C)
-      out = jax.image.resize(x.astype(jnp.float32),
-                             (n, target_h, target_w, c),
-                             method=self.method)
-
+    out = jax.image.resize(x.astype(jnp.float32),
+                           (n, target_h, target_w, c),
+                           method=self.method)
     return out.astype(input_dtype)
 
 class Identity(nnx.Module):
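WanUpsample is now strictly 2-D: callers flatten any depth axis into the batch (as WanResample does below) before resizing. A minimal sketch of the resulting behavior of jax.image.resize:

import jax
import jax.numpy as jnp

x = jnp.ones((2, 8, 8, 3))  # (N, H, W, C); a video chunk arrives here as (N*D, H, W, C)
out = jax.image.resize(x.astype(jnp.float32), (2, 16, 16, 3), method="nearest")
assert out.shape == (2, 16, 16, 3)  # H and W doubled, N and C untouched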
@@ -256,7 +248,7 @@ def __init__(
       )
     elif mode == "upsample3d":
       self.resample = nnx.Sequential(
-          WanUpsample(scale_factor=(2.0, 2.0), method="nearest"),
+          WanUpsample(scale_factor=(2.0, 2.0, 2.0), method="nearest"),
           nnx.Conv(
               dim,
               dim // 2,
@@ -305,6 +297,30 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
     n, d, h, w, c = x.shape
     assert c == self.dim
 
+    if self.mode == "upsample3d":
+      if feat_cache is not None:
+        idx = feat_idx[0]
+        if feat_cache[idx] is None:
+          feat_cache[idx] = "Rep"
+          feat_idx[0] += 1
+        else:
+          cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+          if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+            # cache the last frame of the last two chunks
+            cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+          if cache_x.shape[1] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+            cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1)
+          if feat_cache[idx] == "Rep":
+            x = self.time_conv(x)
+          else:
+            x = self.time_conv(x, feat_cache[idx])
+          feat_cache[idx] = cache_x
+          feat_idx[0] += 1
+          # time_conv doubled the channels: split them into (2, c) and interleave along time
+          x = x.reshape(n, d, h, w, 2, c)
+          x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)
+          x = x.reshape(n, d * 2, h, w, c)
+          d = x.shape[1]
     x = x.reshape(n * d, h, w, c)
     x = self.resample(x)
     h_new, w_new, c_new = x.shape[1:]
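The temporal half of upsample3d works by letting time_conv double the channel axis and then unpacking those channels into extra frames. A shape-only sketch of the interleave (the values and sizes are illustrative):

import jax.numpy as jnp

n, d, h, w, c = 1, 3, 4, 4, 2
x = jnp.zeros((n, d, h, w, 2 * c))  # time_conv output: channels doubled
x = x.reshape(n, d, h, w, 2, c)     # split channels into (2, c)
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=2)  # (n, d, 2, h, w, c)
x = x.reshape(n, d * 2, h, w, c)    # row-major flatten interleaves: d0a, d0b, d1a, ...
assert x.shape == (1, 6, 4, 4, 2)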
@@ -371,7 +386,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
     if feat_cache is not None:
       idx = feat_idx[0]
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-      if cache_x.shape[1] <2 and feat_cache[idx] is not None:
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
 
       x = self.conv1(x, feat_cache[idx])
@@ -387,7 +402,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
     if feat_cache is not None:
       idx = feat_idx[0]
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-      if cache_x.shape[1] <2 and feat_cache[idx] is not None:
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
         cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
       x = self.conv2(x, feat_cache[idx])
       feat_cache[idx] = cache_x
@@ -458,7 +473,7 @@ def __init__(
     attentions = []
     for _ in range(num_layers):
       attentions.append(WanAttentionBlock(dim=dim, rngs=rngs))
-      resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs,dropout=dropout, non_linearity=non_linearity))
+      resnets.append(WanResidualBlock(in_dim=dim, out_dim=dim, rngs=rngs, dropout=dropout, non_linearity=non_linearity))
     self.attentions = attentions
     self.resnets = resnets
 
@@ -467,7 +482,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
     for attn, resnet in zip(self.attentions, self.resnets[1:]):
       if attn is not None:
         x = attn(x)
-      x = resnet(x)
+      x = resnet(x, feat_cache, feat_idx)
     return x
 
 class WanUpBlock(nnx.Module):
@@ -482,19 +497,31 @@ def __init__(
       non_linearity: str = "silu"
   ):
     # Create layers list
-    self.resnets = []
+    resnets = []
     # Add residual blocks and attention if needed
     current_dim = in_dim
     for _ in range(num_res_blocks + 1):
-      self.resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs))
+      resnets.append(WanResidualBlock(in_dim=current_dim, out_dim=out_dim, dropout=dropout, non_linearity=non_linearity, rngs=rngs))
       current_dim = out_dim
+    self.resnets = resnets
 
     # Add upsampling layer if needed.
     self.upsamplers = None
     if upsample_mode is not None:
-      self.upsamplers = WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)
+      self.upsamplers = [WanResample(dim=out_dim, mode=upsample_mode, rngs=rngs)]
 
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
+    for resnet in self.resnets:
+      if feat_cache is not None:
+        x = resnet(x, feat_cache, feat_idx)
+      else:
+        x = resnet(x)
+
+    if self.upsamplers is not None:
+      if feat_cache is not None:
+        x = self.upsamplers[0](x, feat_cache, feat_idx)
+      else:
+        x = self.upsamplers[0](x)
     return x
 
 class WanEncoder3d(nnx.Module):
@@ -655,7 +682,13 @@ def __init__(
     )
 
     # middle_blocks
-    self.mid_block = WanMidBlock(dim=dims[0], rngs=rngs, dropout=dropout, non_linearity=non_linearity, num_layers=1)
+    self.mid_block = WanMidBlock(
+        dim=dims[0],
+        rngs=rngs,
+        dropout=dropout,
+        non_linearity=non_linearity,
+        num_layers=1
+    )
 
     # upsample blocks
     self.up_blocks = []
@@ -668,7 +701,6 @@ def __init__(
       upsample_mode = None
       if i != len(dim_mult) - 1:
         upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
       # Create and add the upsampling block
       up_block = WanUpBlock(
           in_dim=in_dim,
@@ -686,8 +718,7 @@ def __init__(
       scale *= 2.0
 
     # output blocks
-    self.norm_out = nnx.RMSNorm(num_features=out_dim, )
-    self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs)
+    self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
     self.conv_out = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -697,7 +728,39 @@ def __init__(
     )
 
   def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
-    x = self.conv_in(x)
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        # cache the last frame of the last two chunks
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv_in(x, feat_cache[idx])
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv_in(x)
+
+    ## middle
+    x = self.mid_block(x, feat_cache, feat_idx)
+
+    ## upsamples
+    for up_block in self.up_blocks:
+      x = up_block(x, feat_cache, feat_idx)
+
+    ## head
+    x = self.norm_out(x)
+    x = self.nonlinearity(x)
+    if feat_cache is not None:
+      idx = feat_idx[0]
+      cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
+      if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
+        # cache the last frame of the last two chunks
+        cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
+      x = self.conv_out(x, feat_cache[idx])
+      feat_cache[idx] = cache_x
+      feat_idx[0] += 1
+    else:
+      x = self.conv_out(x)
     return x
 
 class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
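Throughout the decoder, feat_idx=[0] is a single-element list used as a mutable cursor: each cached convolution reads its own slot from feat_cache and advances the shared counter, so nested blocks consume consecutive slots in call order. A hypothetical illustration of the idiom (consume and the slot values are made up for this sketch):

def consume(feat_cache, feat_idx):
  idx = feat_idx[0]         # slot owned by this conv
  feat_cache[idx] = "used"  # stash this chunk's trailing frames
  feat_idx[0] += 1          # advance the shared cursor in place

feat_map = [None] * 3
cursor = [0]
for _ in range(3):
  consume(feat_map, cursor)
assert feat_map == ["used"] * 3 and cursor[0] == 3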
@@ -723,6 +786,8 @@ def __init__(
     self.z_dim = z_dim
     self.temperal_downsample = temperal_downsample
     self.temporal_upsample = temperal_downsample[::-1]
+    self.latents_mean = latents_mean
+    self.latents_std = latents_std
 
     self.encoder = WanEncoder3d(
         rngs=rngs,
@@ -747,16 +812,16 @@ def __init__(
         kernel_size=1,
     )
 
-    # self.decoder = WanDecoder3d(
-    #     rngs=rngs,
-    #     dim=base_dim,
-    #     z_dim=z_dim,
-    #     dim_mult=dim_mult,
-    #     num_res_blocks=num_res_blocks,
-    #     attn_scales=attn_scales,
-    #     temperal_upsample=self.temporal_upsample,
-    #     dropout=dropout
-    # )
+    self.decoder = WanDecoder3d(
+        rngs=rngs,
+        dim=base_dim,
+        z_dim=z_dim,
+        dim_mult=dim_mult,
+        num_res_blocks=num_res_blocks,
+        attn_scales=attn_scales,
+        temperal_upsample=self.temporal_upsample,
+        dropout=dropout
+    )
     self.clear_cache()
 
   def clear_cache(self):
@@ -769,9 +834,9 @@ def _count_conv3d(module):
           count += 1
       return count
 
-    # self._conv_num = _count_conv3d(self.decoder)
-    # self._conv_idx = [0]
-    # self._feat_map = [None] * self._conv_num
+    self._conv_num = _count_conv3d(self.decoder)
+    self._conv_idx = [0]
+    self._feat_map = [None] * self._conv_num
     # cache encode
     self._enc_conv_num = _count_conv3d(self.encoder)
     self._enc_conv_idx = [0]
@@ -817,4 +882,35 @@ def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencode
     if not return_dict:
       return (posterior,)
     return FlaxAutoencoderKLOutput(latent_dist=posterior)
+
+  def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
+    self.clear_cache()
+    iter_ = z.shape[1]
+    x = self.post_quant_conv(z)
+    for i in range(iter_):
+      self._conv_idx = [0]
+      if i == 0:
+        out = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+      else:
+        out_ = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+        out = jnp.concatenate([out, out_], axis=1)
+
+    out = jnp.clip(out, -1.0, 1.0)
+    self.clear_cache()
+    if not return_dict:
+      return (out,)
+
+    return FlaxDecoderOutput(sample=out)
+
+  def decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOutput, jax.Array]:
+    if z.shape[-1] != self.z_dim:
+      # move channels last for JAX
+      z = jnp.transpose(z, (0, 2, 3, 4, 1))
+    assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}), got {z.shape}"
+    decoded = self._decode(z).sample
+    if not return_dict:
+      return (decoded,)
+    return FlaxDecoderOutput(sample=decoded)
 
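With the decoder wired up, decoding walks the latent video one frame at a time, carrying the rolling feature cache between iterations. A usage sketch, assuming a vae constructed exactly as in the test below (shapes follow that test):

import jax.numpy as jnp

latents = jnp.ones((1, 13, 60, 90, 16))  # channels-last latents, z_dim = 16
out = vae.decode(latents)                # vae: AutoencoderKLWan built as in test_wan_decode
video = out.sample                       # (1, 49, 480, 720, 3), clipped to [-1, 1]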
src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 33 additions & 0 deletions
@@ -390,6 +390,39 @@ def test_wan_midblock(self):
     output = wan_midblock(dummy_input)
     assert output.shape == input_shape
 
+  def test_wan_decode(self):
+    key = jax.random.key(0)
+    rngs = nnx.Rngs(key)
+    dim = 96
+    z_dim = 16
+    dim_mult = [1, 2, 4, 4]
+    num_res_blocks = 2
+    attn_scales = []
+    temperal_downsample = [False, True, True]
+    wan_vae = AutoencoderKLWan(
+        rngs=rngs,
+        base_dim=dim,
+        z_dim=z_dim,
+        dim_mult=dim_mult,
+        num_res_blocks=num_res_blocks,
+        attn_scales=attn_scales,
+        temperal_downsample=temperal_downsample,
+    )
+
+    batch = 1
+    t = 13
+    channels = 16
+    height = 60
+    width = 90
+    input_shape = (batch, t, height, width, channels)
+    input = jnp.ones(input_shape)
+
+    latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
+    latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
+    input = input / latents_std + latents_mean
+    dummy_output = wan_vae.decode(input)
+    assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)
+
   def test_wan_encode(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
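The expected output shape in the new test follows directly from the config: dim_mult = [1, 2, 4, 4] gives three spatial 2x stages (8x overall), and temperal_downsample = [False, True, True] gives two temporal 2x stages with a causal first frame that is not doubled. A quick check of the arithmetic:

t_in, h_in, w_in = 13, 60, 90
t_out = (t_in - 1) * 4 + 1         # two temporal 2x stages; the first frame stays single
h_out, w_out = h_in * 8, w_in * 8  # three spatial 2x stages
assert (t_out, h_out, w_out) == (49, 480, 720)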
