Skip to content

Commit 4e443b8

Browse files
add unit tests to wan vae padded conv.
1 parent cc11bb1 commit 4e443b8

3 files changed

Lines changed: 66 additions & 16 deletions

File tree

src/maxdiffusion/generate_wan.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def run(config):
214214
vae=None,
215215
transformer=None
216216
)
217-
breakpoint()
218217

219218
#wan_transformer = WanModel(rngs=nnx.Rngs(config.seed))
220219

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,16 @@ def __call__(self, x):
192192
return x
193193

194194
class ZeroPaddedConv2D(nnx.Module):
195+
"""
196+
Module for adding padding before conv.
197+
Currently it does not add any padding.
198+
"""
195199
def __init__(
196200
self,
197201
dim: int,
198202
rngs: nnx.Rngs,
203+
kernel_size: Union[int, Tuple[int, int, int]],
204+
stride: Union[int, Tuple[int, int, int]] = 1,
199205
flash_min_seq_length: int = 4096,
200206
flash_block_sizes: BlockSizes = None,
201207
mesh: jax.sharding.Mesh = None,
@@ -204,18 +210,18 @@ def __init__(
204210
precision: jax.lax.Precision = None,
205211
attention: str = "dot_product",
206212
):
213+
kernel_size = _canonicalize_tuple(kernel_size, 3, 'kernel_size')
214+
stride = _canonicalize_tuple(stride, 3, 'stride')
207215
self.conv = nnx.Conv(
208216
dim,
209217
dim,
210-
kernel_size=(1, 3, 3),
211-
padding='SAME',
218+
kernel_size=kernel_size,
219+
strides=stride,
212220
use_bias=True,
213221
rngs=rngs
214222
)
215223

216224
def __call__(self, x):
217-
# This pad assumes (B, C, H, W)
218-
x = jax.lax.pad(x, 0.0, [(0, 0, 0), (0, 0, 0), (0, 1, 0), (0, 1, 0)])
219225
return self.conv(x)
220226

221227

@@ -263,9 +269,21 @@ def __init__(
263269
)
264270
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0), rngs=rngs)
265271
elif mode == "downsample2d":
266-
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs)
272+
# TODO - do I need to transpose?
273+
self.resample = ZeroPaddedConv2D(
274+
dim=dim,
275+
rngs=rngs,
276+
kernel_size=(1, 3, 3),
277+
stride=(1, 2, 2)
278+
)
267279
elif mode == "downsample3d":
268-
self.resample = ZeroPaddedConv2D(dim=dim, rngs=rngs)
280+
# TODO - do I need to transpose?
281+
self.resample = ZeroPaddedConv2D(
282+
dim=dim,
283+
rngs=rngs,
284+
kernel_size=(1, 3, 3),
285+
stride=(1, 2, 2)
286+
)
269287
else:
270288
self.resample = Identity()
271289

src/maxdiffusion/tests/wan_vae_test.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
import unittest
2323
import pytest
2424
from absl.testing import absltest
25-
from ..models.wan.autoencoder_kl_wan import WanCausalConv3d, WanUpsample, AutoencoderKLWan, WanRMS_norm
25+
from ..models.wan.autoencoder_kl_wan import (
26+
WanCausalConv3d,
27+
WanUpsample,
28+
AutoencoderKLWan,
29+
WanRMS_norm,
30+
ZeroPaddedConv2D
31+
)
2632

2733
class WanVaeTest(unittest.TestCase):
2834
def setUp(self):
@@ -66,36 +72,63 @@ def forward(self, x):
6672
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
6773

6874
# --- Test Case 1: images == True ---
69-
model = TorchWanRMS_norm(2)
70-
input_shape = (1, 2, 2, 2, 3)
75+
dim = 96
76+
input_shape = (1, 96, 1, 480, 720)
77+
78+
model = TorchWanRMS_norm(dim)
7179
input = torch.ones(input_shape)
7280
torch_output = model(input)
7381
torch_output_np = torch_output.detach().numpy()
7482

7583
key = jax.random.key(0)
7684
rngs = nnx.Rngs(key)
77-
wanrms_norm = WanRMS_norm(dim=2, rngs=rngs)
78-
input_shape = (1, 2, 2, 2, 3)
85+
wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs)
7986
dummy_input = jnp.ones(input_shape)
8087
output = wanrms_norm(dummy_input)
8188
output_np = np.array(output)
8289
assert np.allclose(output_np, torch_output_np) == True
8390

8491
# --- Test Case 2: images == False ---
85-
model = TorchWanRMS_norm(2, images=False)
86-
input_shape = (1, 2, 2, 2, 3)
92+
model = TorchWanRMS_norm(dim, images=False)
8793
input = torch.ones(input_shape)
8894
torch_output = model(input)
8995
torch_output_np = torch_output.detach().numpy()
9096

9197
key = jax.random.key(0)
9298
rngs = nnx.Rngs(key)
93-
wanrms_norm = WanRMS_norm(dim=2, rngs=rngs, images=False)
94-
input_shape = (1, 2, 2, 2, 3)
99+
wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False)
95100
dummy_input = jnp.ones(input_shape)
96101
output = wanrms_norm(dummy_input)
97102
output_np = np.array(output)
98103
assert np.allclose(output_np, torch_output_np) == True
104+
105+
def test_zero_padded_conv(self):
106+
import torch
107+
import torch.nn as nn
108+
109+
key = jax.random.key(0)
110+
rngs = nnx.Rngs(key)
111+
112+
dim = 96
113+
kernel_size = 3
114+
stride= (2, 2)
115+
resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride))
116+
input_shape = (1, 96, 480, 720)
117+
input = torch.ones(input_shape)
118+
output_torch = resample(input)
119+
assert output_torch.shape == (1, 96, 240, 360)
120+
121+
model = ZeroPaddedConv2D(
122+
dim=dim,
123+
rngs=rngs,
124+
kernel_size=(1, 3, 3),
125+
stride=(1, 2, 2)
126+
)
127+
dummy_input = jnp.ones(input_shape)
128+
dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
129+
output = model(dummy_input)
130+
output = jnp.transpose(output, (0, 3, 1, 2))
131+
assert output.shape == (1, 96, 240, 360)
99132

100133
def test_wan_upsample(self):
101134
batch_size=1

0 commit comments

Comments
 (0)