Skip to content

Commit aeabe27

Browse files
wip - test for vae encoder.
1 parent 4e443b8 commit aeabe27

2 files changed

Lines changed: 256 additions & 28 deletions

File tree

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,13 @@ def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> T
5555
class WanCausalConv3d(nnx.Module):
5656
def __init__(
5757
self,
58+
rngs: nnx.Rngs, # rngs are required for initializing parameters,
5859
in_channels: int,
5960
out_channels: int,
6061
kernel_size: Union[int, Tuple[int, int, int]],
61-
*, # Mark subsequent arguments as keyword-only
6262
stride: Union[int, Tuple[int, int, int]] = 1,
6363
padding: Union[int, Tuple[int, int, int]] = 0,
6464
use_bias: bool = True,
65-
rngs: nnx.Rngs, # rngs are required for initializing parameters,
6665
flash_min_seq_length: int = 4096,
6766
flash_block_sizes: BlockSizes = None,
6867
mesh: jax.sharding.Mesh = None,
@@ -267,7 +266,13 @@ def __init__(
267266
rngs=rngs,
268267
)
269268
)
270-
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs)
269+
self.time_conv = WanCausalConv3d(
270+
rngs=rngs,
271+
in_channels=dim,
272+
out_channels=dim * 2,
273+
kernel_size=(3, 1, 1),
274+
padding=(1, 0, 0),
275+
)
271276
elif mode == "downsample2d":
272277
# TODO - do I need to transpose?
273278
self.resample = ZeroPaddedConv2D(
@@ -288,6 +293,15 @@ def __init__(
288293
self.resample = Identity()
289294

290295
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]) -> jax.Array:
  """Apply the configured spatial resample to every frame of a video.

  Folds the time axis into the batch, runs the 2D resample op, and restores
  the (N, D, H, W, C) layout.

  NOTE(review): `feat_cache` / `feat_idx` are accepted but unused, and the
  upsample3d/downsample3d `time_conv` paths are not applied here —
  presumably WIP, confirm. `feat_idx=[0]` is a mutable default (shared
  across calls); kept to mirror the torch reference signature.
  """
  # Input x: (N, D, H, W, C), assume C = self.dim
  n, d, h, w, c = x.shape
  assert c == self.dim

  # Merge batch and time so self.resample (a 2D op) runs per frame.
  x = x.reshape(n*d,h,w,c)
  x = self.resample(x)
  # Spatial dims (and possibly channels) change depending on the mode.
  h_new, w_new, c_new = x.shape[1:]
  x = x.reshape(n, d, h_new, w_new, c_new)

  return x
292306

293307
class WanResidualBlock(nnx.Module):
@@ -382,7 +396,13 @@ def __init__(
382396
scale = 1.0
383397

384398
# init block
385-
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1, rngs=rngs)
399+
self.conv_in = WanCausalConv3d(
400+
in_channels=3,
401+
out_channels=dims[0],
402+
kernel_size=3,
403+
padding=1,
404+
rngs=rngs
405+
)
386406

387407
# downsample blocks
388408
self.down_blocks = []
@@ -400,11 +420,23 @@ def __init__(
400420
self.down_blocks.append(WanResample(out_dim, mode=mode, rngs=rngs))
401421

402422
# middle_blocks
403-
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1, rngs=rngs)
423+
self.mid_block = WanMidBlock(
424+
dim=out_dim,
425+
rngs=rngs,
426+
dropout=dropout,
427+
non_linearity=non_linearity,
428+
num_layers=1,
429+
)
404430

405431
# output blocks
406432
self.norm_out = WanRMS_norm(out_dim, images=False, rngs=rngs)
407-
self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=z_dim, kernel_size=3, padding=1)
433+
self.conv_out = WanCausalConv3d(
434+
rngs=rngs,
435+
in_channels=out_dim,
436+
out_channels=z_dim,
437+
kernel_size=3,
438+
padding=1
439+
)
408440

409441
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
  # WIP stub: the encoder forward pass is not implemented yet; the input is
  # passed through unchanged so construction/plumbing can be exercised.
  return x
@@ -487,6 +519,9 @@ def __init__(
487519
self.conv_out = WanCausalConv3d(in_channels=out_dim, out_channels=3, kernel_size=3, padding=1, rngs=rngs)
488520

489521
def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
  """WIP decoder forward pass: currently only applies the input convolution.

  Args:
    x: Input latent video array.
    feat_cache: Causal-conv feature cache (unused so far).
    feat_idx: Index cursor into the cache (unused so far; mutable default
      kept to mirror the torch reference signature).

  Returns:
    The result of `conv_in` applied to `x`.
  """
  # Bug fix: removed two stray `breakpoint()` debugging calls, which would
  # halt any non-interactive run (tests, CI).
  x = self.conv_in(x)
  return x
491526

492527
class AutoencoderKLWan(nnx.Module, FlaxModelMixin, ConfigMixin):
@@ -514,6 +549,7 @@ def __init__(
514549
self.temporal_upsample = temporal_downsample[::-1]
515550

516551
self.encoder = WanEncoder3d(z_dim * 2, z_dim * 2, 1)
552+
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1, rngs=rngs)
517553
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1, rngs=rngs)
518554

519555
self.decoder = WanDecoder3d(
@@ -539,10 +575,42 @@ def _count_conv3d(module):
539575
self._enc_conv_idx = [0]
540576
self._enc_feat_map = [None] * self._enc_conv_num
541577

578+
def _encode(self, x: jax.Array):
  """Run the encoder over the video in causal temporal chunks.

  Args:
    x: Video array, either channel-last (N, D, H, W, 3) or channel-first
       (N, 3, D, H, W) — channel-first input is transposed to channel-last.

  Returns:
    The quantized encoder output `enc` with mean/log-variance halves
    concatenated along the channel (last) axis.
  """
  if x.shape[-1] != 3:
    # reshape channel last for JAX
    x = jnp.transpose(x, (0, 2, 3, 4, 1))
  assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"

  self.clear_cache()

  t = x.shape[1]
  # First chunk is a single frame; remaining frames are processed four at a
  # time so the causal feature cache lines up with temporal downsampling.
  iter_ = 1 + (t - 1) // 4
  for i in range(iter_):
    if i == 0:
      out = self.encoder(
          x[:, :1, :, :, :],
          feat_cache=self._enc_feat_map,
          # Bug fix: was `feat_ids=` — invalid keyword, raises TypeError.
          feat_idx=self._enc_conv_idx,
      )
    else:
      out_ = self.encoder(
          # Bug fix: added the missing trailing `:` so the slice spells out
          # all five axes, consistent with the i == 0 branch.
          x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
          feat_cache=self._enc_feat_map,
          feat_idx=self._enc_conv_idx,
      )
      out = jnp.concatenate([out, out_], axis=1)

  enc = self.quant_conv(out)
  mu, logvar = enc[:, :, :, :, : self.z_dim], enc[:, :, :, :, self.z_dim :]
  # Bug fix: jnp.concatenate takes `axis`, not torch's `dim`; mu/logvar were
  # split on the last (channel) axis, so they must be rejoined on axis=-1.
  enc = jnp.concatenate([mu, logvar], axis=-1)
  self.clear_cache()
  return enc
542608

543609
def encode(self, x: jax.Array, return_dict: bool = True) -> Union[FlaxAutoencoderKLOutput, Tuple[FlaxDiagonalGaussianDistribution]]:
  """ Encode video into latent distribution.

  Args:
    x: Input video; `_encode` accepts channel-last (N, D, H, W, 3) or
       channel-first input and normalizes the layout.
    return_dict: When False, return a 1-tuple instead of the output class.

  Returns:
    `FlaxAutoencoderKLOutput` wrapping the diagonal-Gaussian posterior, or
    the tuple `(posterior,)` when `return_dict` is False.
  """
  h = self._encode(x)
  posterior = FlaxDiagonalGaussianDistribution(h)
  if not return_dict:
    return (posterior, )
  # Bug fix: the diffusers output dataclass field is `latent_dist`,
  # not `latent_dict`.
  return FlaxAutoencoderKLOutput(latent_dist=posterior)
616+

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 178 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
"""
1616

1717
import os
18+
import torch
19+
import torch.nn as nn
20+
import torch.nn.functional as F
1821
import jax
1922
import jax.numpy as jnp
2023
from flax import nnx
@@ -26,27 +29,15 @@
2629
WanCausalConv3d,
2730
WanUpsample,
2831
AutoencoderKLWan,
32+
WanEncoder3d,
2933
WanRMS_norm,
34+
WanResample,
3035
ZeroPaddedConv2D
3136
)
3237

33-
class WanVaeTest(unittest.TestCase):
34-
def setUp(self):
35-
WanVaeTest.dummy_data = {}
36-
37-
# def test_clear_cache(self):
38-
# key = jax.random.key(0)
39-
# rngs = nnx.Rngs(key)
40-
# wan_vae = AutoencoderKLWan(rngs=rngs)
41-
# wan_vae.clear_cache()
38+
CACHE_T = 2
4239

43-
def test_wanrms_norm(self):
44-
"""Test against the Pytorch implementation"""
45-
import torch
46-
import torch.nn as nn
47-
import torch.nn.functional as F
48-
49-
class TorchWanRMS_norm(nn.Module):
40+
class TorchWanRMS_norm(nn.Module):
5041
r"""
5142
A custom RMS normalization layer.
5243
@@ -70,6 +61,103 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
7061

7162
def forward(self, x):
  # L2-normalize across channels (dim 1 when channel-first, else the last
  # axis), then apply the learned scale/gamma and optional bias; these
  # attributes are set in __init__ (not shown in this hunk).
  return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
64+
65+
class TorchWanResample(nn.Module):
  r"""
  A custom resampling module for 2D and 3D data.

  Args:
    dim (int): The number of input/output channels.
    mode (str): The resampling mode. Must be one of:
      - 'none': No resampling (identity operation).
      - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
      - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
      - 'downsample2d': 2D downsampling with zero-padding and convolution.
      - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.

  NOTE(review): `WanUpsample` and `WanCausalConv3d` referenced below are the
  JAX (nnx) classes imported from autoencoder_kl_wan in this test file, not
  torch modules — so the upsample3d/downsample3d branches cannot work as
  torch layers as written. Only the 2D branches are exercised by the current
  test; confirm whether the torch originals were intended here.
  """

  def __init__(self, dim: int, mode: str) -> None:
    super().__init__()
    self.dim = dim
    self.mode = mode

    # layers
    if mode == "upsample2d":
      self.resample = nn.Sequential(
          WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
      )
    elif mode == "upsample3d":
      self.resample = nn.Sequential(
          WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
      )
      self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

    elif mode == "downsample2d":
      self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
    elif mode == "downsample3d":
      self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
      self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

    else:
      self.resample = nn.Identity()

  def forward(self, x, feat_cache=None, feat_idx=[0]):
    # x is channel-first torch layout: (B, C, T, H, W).
    b, c, t, h, w = x.size()
    if self.mode == "upsample3d":
      if feat_cache is not None:
        idx = feat_idx[0]
        if feat_cache[idx] is None:
          # Sentinel meaning "first chunk: replicate-pad the causal conv".
          feat_cache[idx] = "Rep"
          feat_idx[0] += 1
        else:
          # Keep the trailing CACHE_T frames for the next chunk's causal conv.
          cache_x = x[:, :, -CACHE_T:, :, :].clone()
          if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
            # cache last frame of last two chunk
            cache_x = torch.cat(
                [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
            )
          if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
            cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
          if feat_cache[idx] == "Rep":
            x = self.time_conv(x)
          else:
            x = self.time_conv(x, feat_cache[idx])
          feat_cache[idx] = cache_x
          feat_idx[0] += 1

          # time_conv doubled the channel count; interleave the two halves
          # along time to realize the 2x temporal upsample.
          x = x.reshape(b, 2, c, t, h, w)
          x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
          x = x.reshape(b, c, t * 2, h, w)
    t = x.shape[2]
    # Fold time into the batch so the 2D resample ops apply per frame.
    x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    x = self.resample(x)
    x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)

    if self.mode == "downsample3d":
      if feat_cache is not None:
        idx = feat_idx[0]
        if feat_cache[idx] is None:
          feat_cache[idx] = x.clone()
          feat_idx[0] += 1
        else:
          # Prepend the cached last frame so the strided causal conv sees
          # continuous time across chunk boundaries.
          cache_x = x[:, :, -1:, :, :].clone()
          x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
          feat_cache[idx] = cache_x
          feat_idx[0] += 1
    return x
148+
149+
class WanVaeTest(unittest.TestCase):
150+
def setUp(self):
  # NOTE(review): this writes to the class attribute (shared by every test
  # in the class) rather than an instance attribute — confirm intentional.
  WanVaeTest.dummy_data = {}
152+
153+
# def test_clear_cache(self):
154+
# key = jax.random.key(0)
155+
# rngs = nnx.Rngs(key)
156+
# wan_vae = AutoencoderKLWan(rngs=rngs)
157+
# wan_vae.clear_cache()
158+
159+
def test_wanrms_norm(self):
160+
"""Test against the Pytorch implementation"""
73161

74162
# --- Test Case 1: images == True ---
75163
dim = 96
@@ -103,8 +191,6 @@ def forward(self, x):
103191
assert np.allclose(output_np, torch_output_np) == True
104192

105193
def test_zero_padded_conv(self):
106-
import torch
107-
import torch.nn as nn
108194

109195
key = jax.random.key(0)
110196
rngs = nnx.Rngs(key)
@@ -148,6 +234,49 @@ def test_wan_upsample(self):
148234
# --- Test Case 1: depth == 1 ---
149235
output = upsample(dummy_input)
150236
assert output.shape == (1, 1, 64, 64, 3)
237+
238+
def test_wan_resample(self):
  """Compare the JAX WanResample against the torch reference (downsample2d)."""
  # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity
  key = jax.random.key(0)
  rngs = nnx.Rngs(key)

  # --- Test Case 1: downsample2d ---
  batch = 1
  dim = 96
  t = 1
  h = 480
  w = 720
  mode = "downsample2d"
  # Torch reference is channel-first (N, C, T, H, W); spatial dims halve.
  input_shape = (batch, dim, t, h, w)
  dummy_input = torch.ones(input_shape)
  torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
  torch_output = torch_wan_resample(dummy_input)
  assert torch_output.shape == (batch, dim, t, h // 2, w // 2)

  # JAX implementation: channels are always last, (N, T, H, W, C).
  wan_resample = WanResample(dim, mode=mode, rngs=rngs)
  input_shape = (batch, t, h, w, dim)
  dummy_input = jnp.ones(input_shape)
  output = wan_resample(dummy_input)
  # Bug fix: the width must be compared against w // 2 (the draft compared
  # it against h // 2); also removed a stray breakpoint() so the suite can
  # run unattended.
  assert output.shape == (batch, t, h // 2, w // 2, dim)

  # TODO - Test Case 2: downsample3d. The previous draft instantiated the
  # JAX WanResample without the required `rngs` (a TypeError) and never
  # called it; add a real reference comparison once the 3D path works.
151280

152281
def test_3d_conv(self):
153282
key = jax.random.key(0)
@@ -189,5 +318,36 @@ def test_3d_conv(self):
189318
output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
190319
assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
191320

321+
def test_wan_encode(self):
  """Smoke test: construct WanEncoder3d and run a dummy video through it."""
  key = jax.random.key(0)
  rngs = nnx.Rngs(key)
  dim = 96
  z_dim = 32
  dim_mult = [1, 2, 4, 4]
  num_res_blocks = 2
  attn_scales = []
  temperal_downsample = [False, True, True]
  nonlinearity = "silu"
  wan_encoder = WanEncoder3d(
      rngs=rngs,
      dim=dim,
      z_dim=z_dim,
      dim_mult=dim_mult,
      num_res_blocks=num_res_blocks,
      attn_scales=attn_scales,
      temperal_downsample=temperal_downsample,
      non_linearity=nonlinearity,
  )
  batch = 1
  channels = 3
  t = 49
  height = 480
  width = 720
  # NOTE(review): this shape is channel-first (N, C, T, H, W); the JAX
  # modules in this file take channel-last input — confirm the intended
  # layout before asserting on output shapes.
  input_shape = (batch, channels, t, height, width)
  # Renamed from `input` to avoid shadowing the builtin.
  dummy_input = jnp.ones(input_shape)
  output = wan_encoder(dummy_input)
  # Minimal sanity check while the encoder forward pass is still WIP.
  assert output is not None
349+
350+
351+
192352
if __name__ == "__main__":
193353
absltest.main()

0 commit comments

Comments
 (0)