Skip to content

Commit a5e1e95

Browse files
finishes vae encoder with matching shapes
1 parent 4325325 commit a5e1e95

2 files changed

Lines changed: 74 additions & 30 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 67 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -289,6 +289,14 @@ def __init__(
289289
kernel_size=(1, 3, 3),
290290
stride=(1, 2, 2)
291291
)
292+
self.time_conv = WanCausalConv3d(
293+
rngs=rngs,
294+
in_channels = dim,
295+
out_channels = dim,
296+
kernel_size=(3, 1, 1),
297+
stride=(2, 1, 1),
298+
padding= (0, 0, 0)
299+
)
292300
else:
293301
self.resample = Identity()
294302

@@ -302,6 +310,18 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
302310
h_new, w_new, c_new = x.shape[1:]
303311
x = x.reshape(n, d, h_new, w_new, c_new)
304312

313+
if self.mode == "downsample3d":
314+
if feat_cache is not None:
315+
idx = feat_idx[0]
316+
if feat_cache[idx] is None:
317+
feat_cache[idx] = jnp.copy(x)
318+
feat_idx[0] +=1
319+
else:
320+
cache_x = jnp.copy(x[:, -1:, :, :, :])
321+
x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
322+
feat_cache[idx] = cache_x
323+
feat_idx[0] += 1
324+
305325
return x
306326

307327
class WanResidualBlock(nnx.Module):
@@ -343,7 +363,6 @@ def __init__(
343363

344364
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
345365
# Apply shortcut connection
346-
#breakpoint()
347366
h = self.conv_shortcut(x)
348367

349368
x = self.norm1(x)
@@ -505,7 +524,8 @@ def __init__(
505524
if i != len(dim_mult) - 1:
506525
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
507526
self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs))
508-
527+
scale /= 2.0
528+
509529
# middle_blocks
510530
self.mid_block = WanMidBlock(
511531
dim=out_dim,
@@ -516,7 +536,12 @@ def __init__(
516536
)
517537

518538
# output blocks
519-
self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs)
539+
self.norm_out = WanRMS_norm(
540+
out_dim,
541+
channel_first=False,
542+
images=False,
543+
rngs=rngs
544+
)
520545
self.conv_out = WanCausalConv3d(
521546
rngs=rngs,
522547
in_channels=out_dim,
@@ -526,14 +551,39 @@ def __init__(
526551
)
527552

528553
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
529-
# (1, 1, 480, 720, 3)
530-
x = self.conv_in(x)
554+
if feat_cache is not None:
555+
idx = feat_idx[0]
556+
cache_x = jnp.copy(x[:, -CACHE_T:, :, :])
557+
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
558+
# fewer than 2 frames in this chunk: prepend the last frame of the previous chunk's cache
559+
cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
560+
x = self.conv_in(x, feat_cache[idx])
561+
feat_cache[idx] = cache_x
562+
feat_idx[0] +=1
563+
else:
564+
x = self.conv_in(x)
531565
# (1, 1, 480, 720, 96)
532566
for layer in self.down_blocks:
533-
x = layer(x)
567+
if feat_cache is not None:
568+
x = layer(x, feat_cache, feat_idx)
569+
else:
570+
x = layer(x)
534571

535-
x = self.mid_block(x)
536-
breakpoint()
572+
x = self.mid_block(x, feat_cache, feat_idx)
573+
574+
x = self.norm_out(x)
575+
x = self.nonlinearity(x)
576+
if feat_cache is not None:
577+
idx = feat_idx[0]
578+
cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
579+
if cache_x.shape[1] < 2 and feat_cache[idx] is not None:
580+
# fewer than 2 frames in this chunk: prepend the last frame of the previous chunk's cache
581+
cache_x = jnp.concatenate([jnp.expand_dims(feat_cache[idx][:, -1, :, :, :], axis=1), cache_x], axis=1)
582+
x = self.conv_out(x, feat_cache[idx])
583+
feat_cache[idx] = cache_x
584+
feat_idx[0] +=1
585+
else:
586+
x = self.conv_out(x)
537587
return x
538588

539589
class WanDecoder3d(nnx.Module):
@@ -626,9 +676,7 @@ def __init__(
626676
)
627677

628678
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
629-
breakpoint()
630679
x = self.conv_in(x)
631-
breakpoint()
632680
return x
633681

634682
class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -696,9 +744,7 @@ def _count_conv3d(module):
696744
count = 0
697745
node_types = nnx.graph.iter_graph([module])
698746
for path, value in node_types:
699-
#breakpoint()
700747
if isinstance(value, WanCausalConv3d):
701-
print("value: ", value)
702748
count +=1
703749
return count
704750

@@ -711,6 +757,7 @@ def _count_conv3d(module):
711757
self._enc_feat_map = [None] * self._enc_conv_num
712758

713759
def _encode(self, x: jax.Array):
760+
self.clear_cache()
714761
if x.shape[-1] != 3:
715762
# reshape channel last for JAX
716763
x = jnp.transpose(x, (0, 2, 3, 4, 1))
@@ -721,6 +768,7 @@ def _encode(self, x: jax.Array):
721768
t = x.shape[1]
722769
iter_ = 1 + (t - 1) // 4
723770
for i in range(iter_):
771+
self._enc_conv_idx = [0]
724772
if i == 0:
725773
out = self.encoder(
726774
x[:, :1, :, :, :],
@@ -729,24 +777,23 @@ def _encode(self, x: jax.Array):
729777
)
730778
else:
731779
out_ = self.encoder(
732-
x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
780+
x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
733781
feat_cache=self._enc_feat_map,
734782
feat_idx=self._enc_conv_idx
735783
)
736784
out = jnp.concatenate([out, out_], axis=1)
737-
738-
# enc = self.quant_conv(out)
739-
# mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
740-
# enc = jnp.concatenate([mu, logvar], dim=1)
741-
# self.clear_cache()
785+
enc = self.quant_conv(out)
786+
mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
787+
enc = jnp.concatenate([mu, logvar], axis=-1)
788+
self.clear_cache()
742789
# return enc
743-
return x
790+
return enc
744791

745792
def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
746793
""" Encode video into latent distribution."""
747794
h = self._encode(x)
748795
posterior = FlaxDiagonalGaussianDistribution(h)
749796
if not return_dict:
750797
return (posterior, )
751-
return FlaxAutoencoderKLOutput(latent_dict=posterior)
798+
return FlaxAutoencoderKLOutput(latent_dist=posterior)
752799

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ class WanVaeTest(unittest.TestCase):
153153
def setUp(self):
154154
WanVaeTest.dummy_data = {}
155155

156-
# def test_clear_cache(self):
157-
# key = jax.random.key(0)
158-
# rngs = nnx.Rngs(key)
159-
# wan_vae = AutoencoderKLWan(rngs=rngs)
160-
# wan_vae.clear_cache()
156+
def test_clear_cache(self):
157+
key = jax.random.key(0)
158+
rngs = nnx.Rngs(key)
159+
wan_vae = AutoencoderKLWan(rngs=rngs)
160+
wan_vae.clear_cache()
161161

162162
def test_wanrms_norm(self):
163163
"""Test against the Pytorch implementation"""
@@ -394,12 +394,11 @@ def test_wan_encode(self):
394394
key = jax.random.key(0)
395395
rngs = nnx.Rngs(key)
396396
dim = 96
397-
z_dim = 32
397+
z_dim = 16
398398
dim_mult = [1, 2, 4, 4]
399399
num_res_blocks = 2
400400
attn_scales = []
401401
temperal_downsample = [False, True, True]
402-
nonlinearity = "silu"
403402
wan_vae = AutoencoderKLWan(
404403
rngs=rngs,
405404
base_dim=dim,
@@ -417,9 +416,7 @@ def test_wan_encode(self):
417416
input_shape = (batch, channels, t, height, width)
418417
input = jnp.ones(input_shape)
419418
output = wan_vae.encode(input)
420-
breakpoint()
421-
422-
419+
assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)
423420

424421
if __name__ == "__main__":
425422
absltest.main()

0 commit comments

Comments
 (0)