|
22 | 22 | import unittest |
23 | 23 | import pytest |
24 | 24 | from absl.testing import absltest |
25 | | -from ..models.wan.autoencoder_kl_wan import WanCausalConv3d, WanUpsample, AutoencoderKLWan, WanRMS_norm |
| 25 | +from ..models.wan.autoencoder_kl_wan import ( |
| 26 | + WanCausalConv3d, |
| 27 | + WanUpsample, |
| 28 | + AutoencoderKLWan, |
| 29 | + WanRMS_norm, |
| 30 | + ZeroPaddedConv2D |
| 31 | +) |
26 | 32 |
|
27 | 33 | class WanVaeTest(unittest.TestCase): |
28 | 34 | def setUp(self): |
@@ -66,36 +72,63 @@ def forward(self, x): |
66 | 72 | return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias |
67 | 73 |
|
68 | 74 | # --- Test Case 1: images == True --- |
69 | | - model = TorchWanRMS_norm(2) |
70 | | - input_shape = (1, 2, 2, 2, 3) |
| 75 | + dim = 96 |
| 76 | + input_shape = (1, 96, 1, 480, 720) |
| 77 | + |
| 78 | + model = TorchWanRMS_norm(dim) |
71 | 79 | input = torch.ones(input_shape) |
72 | 80 | torch_output = model(input) |
73 | 81 | torch_output_np = torch_output.detach().numpy() |
74 | 82 |
|
75 | 83 | key = jax.random.key(0) |
76 | 84 | rngs = nnx.Rngs(key) |
77 | | - wanrms_norm = WanRMS_norm(dim=2, rngs=rngs) |
78 | | - input_shape = (1, 2, 2, 2, 3) |
| 85 | + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs) |
79 | 86 | dummy_input = jnp.ones(input_shape) |
80 | 87 | output = wanrms_norm(dummy_input) |
81 | 88 | output_np = np.array(output) |
82 | 89 | assert np.allclose(output_np, torch_output_np) == True |
83 | 90 |
|
84 | 91 | # --- Test Case 2: images == False --- |
85 | | - model = TorchWanRMS_norm(2, images=False) |
86 | | - input_shape = (1, 2, 2, 2, 3) |
| 92 | + model = TorchWanRMS_norm(dim, images=False) |
87 | 93 | input = torch.ones(input_shape) |
88 | 94 | torch_output = model(input) |
89 | 95 | torch_output_np = torch_output.detach().numpy() |
90 | 96 |
|
91 | 97 | key = jax.random.key(0) |
92 | 98 | rngs = nnx.Rngs(key) |
93 | | - wanrms_norm = WanRMS_norm(dim=2, rngs=rngs, images=False) |
94 | | - input_shape = (1, 2, 2, 2, 3) |
| 99 | + wanrms_norm = WanRMS_norm(dim=dim, rngs=rngs, images=False) |
95 | 100 | dummy_input = jnp.ones(input_shape) |
96 | 101 | output = wanrms_norm(dummy_input) |
97 | 102 | output_np = np.array(output) |
98 | 103 | assert np.allclose(output_np, torch_output_np) == True |
| 104 | + |
| 105 | + def test_zero_padded_conv(self): |
| 106 | + import torch |
| 107 | + import torch.nn as nn |
| 108 | + |
| 109 | + key = jax.random.key(0) |
| 110 | + rngs = nnx.Rngs(key) |
| 111 | + |
| 112 | + dim = 96 |
| 113 | + kernel_size = 3 |
| 114 | + stride= (2, 2) |
| 115 | + resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, kernel_size, stride=stride)) |
| 116 | + input_shape = (1, 96, 480, 720) |
| 117 | + input = torch.ones(input_shape) |
| 118 | + output_torch = resample(input) |
| 119 | + assert output_torch.shape == (1, 96, 240, 360) |
| 120 | + |
| 121 | + model = ZeroPaddedConv2D( |
| 122 | + dim=dim, |
| 123 | + rngs=rngs, |
| 124 | + kernel_size=(1, 3, 3), |
| 125 | + stride=(1, 2, 2) |
| 126 | + ) |
| 127 | + dummy_input = jnp.ones(input_shape) |
| 128 | + dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1)) |
| 129 | + output = model(dummy_input) |
| 130 | + output = jnp.transpose(output, (0, 3, 1, 2)) |
| 131 | + assert output.shape == (1, 96, 240, 360) |
99 | 132 |
|
100 | 133 | def test_wan_upsample(self): |
101 | 134 | batch_size=1 |
|
0 commit comments