1515 """
1616
1717import os
18+ import torch
19+ import torch .nn as nn
20+ import torch .nn .functional as F
1821import jax
1922import jax .numpy as jnp
2023from flax import nnx
2629 WanCausalConv3d ,
2730 WanUpsample ,
2831 AutoencoderKLWan ,
32+ WanEncoder3d ,
2933 WanRMS_norm ,
34+ WanResample ,
3035 ZeroPaddedConv2D
3136)
3237
-class WanVaeTest(unittest.TestCase):
-    def setUp(self):
-        WanVaeTest.dummy_data = {}
-
-    # def test_clear_cache(self):
-    #     key = jax.random.key(0)
-    #     rngs = nnx.Rngs(key)
-    #     wan_vae = AutoencoderKLWan(rngs=rngs)
-    #     wan_vae.clear_cache()
+CACHE_T = 2
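+# CACHE_T trailing frames are carried between chunks so the causal convs can
+# see across chunk boundaries (see the x[:, :, -CACHE_T:, :, :] slices below).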
 
-    def test_wanrms_norm(self):
-        """Test against the Pytorch implementation"""
-        import torch
-        import torch.nn as nn
-        import torch.nn.functional as F
-
-        class TorchWanRMS_norm(nn.Module):
+class TorchWanRMS_norm(nn.Module):
     r"""
     A custom RMS normalization layer.
 
@@ -70,6 +61,103 @@ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bi
 
     def forward(self, x):
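         # F.normalize L2-normalizes along the channel axis; self.scale
         # (presumably dim**0.5, set in the elided __init__) rescales this to
         # RMS normalization: x / rms(x) * gamma + bias.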
         return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+class TorchWanResample(nn.Module):
+    r"""
+    A custom resampling module for 2D and 3D data.
+
+    Args:
+        dim (int): The number of input/output channels.
+        mode (str): The resampling mode. Must be one of:
+            - 'none': No resampling (identity operation).
+            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+            - 'downsample2d': 2D downsampling with zero-padding and convolution.
+            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+    """
+
+    def __init__(self, dim: int, mode: str) -> None:
+        super().__init__()
+        self.dim = dim
+        self.mode = mode
+
+        # layers
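+        # NOTE: WanUpsample / WanCausalConv3d below resolve to the flax imports
+        # at the top of this file (unless Torch counterparts are imported in
+        # the elided lines), so only the downsample2d path is exercised in the
+        # tests below.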
+        if mode == "upsample2d":
+            self.resample = nn.Sequential(
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+            )
+        elif mode == "upsample3d":
+            self.resample = nn.Sequential(
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+            )
+            self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+        elif mode == "downsample2d":
+            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+        elif mode == "downsample3d":
+            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+            self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+        else:
+            self.resample = nn.Identity()
+
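+    # feat_cache protocol: feat_cache[idx] carries trailing frames from the
+    # previous chunk ("Rep" marks the first chunk in upsample3d mode), and
+    # feat_idx[0] is a running slot index so each cached conv uses its own entry.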
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        b, c, t, h, w = x.size()
+        if self.mode == "upsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = "Rep"
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                        # cache the last frame of each of the last two chunks
+                        cache_x = torch.cat(
+                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
+                        )
+                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+                    if feat_cache[idx] == "Rep":
+                        x = self.time_conv(x)
+                    else:
+                        x = self.time_conv(x, feat_cache[idx])
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+
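+                    # time_conv doubled the channel dim; split it into two
+                    # temporal phases and interleave along t (doubling frames)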
+                    x = x.reshape(b, 2, c, t, h, w)
+                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                    x = x.reshape(b, c, t * 2, h, w)
+        t = x.shape[2]
+        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+        x = self.resample(x)
+        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+        if self.mode == "downsample3d":
+            if feat_cache is not None:
+                idx = feat_idx[0]
+                if feat_cache[idx] is None:
+                    feat_cache[idx] = x.clone()
+                    feat_idx[0] += 1
+                else:
+                    cache_x = x[:, :, -1:, :, :].clone()
+                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                    feat_cache[idx] = cache_x
+                    feat_idx[0] += 1
+        return x
+
+class WanVaeTest(unittest.TestCase):
+    def setUp(self):
+        WanVaeTest.dummy_data = {}
+
+    # def test_clear_cache(self):
+    #     key = jax.random.key(0)
+    #     rngs = nnx.Rngs(key)
+    #     wan_vae = AutoencoderKLWan(rngs=rngs)
+    #     wan_vae.clear_cache()
+
+    def test_wanrms_norm(self):
+        """Test against the Pytorch implementation"""
 
         # --- Test Case 1: images == True ---
         dim = 96
@@ -103,8 +191,6 @@ def forward(self, x):
         assert np.allclose(output_np, torch_output_np)
 
     def test_zero_padded_conv(self):
-        import torch
-        import torch.nn as nn
 
         key = jax.random.key(0)
         rngs = nnx.Rngs(key)
@@ -148,6 +234,49 @@ def test_wan_upsample(self):
         # --- Test Case 1: depth == 1 ---
         output = upsample(dummy_input)
         assert output.shape == (1, 1, 64, 64, 3)
+
+    def test_wan_resample(self):
+        # TODO - needs to test all modes - upsample2d, upsample3d, downsample2d, downsample3d and identity
+        key = jax.random.key(0)
+        rngs = nnx.Rngs(key)
+
+        # --- Test Case 1: downsample2d ---
+        batch = 1
+        dim = 96
+        t = 1
+        h = 480
+        w = 720
+        mode = "downsample2d"
+        input_shape = (batch, dim, t, h, w)
+        # downsample2d halves the spatial dims: (1, 96, 1, 240, 360)
+        expected_output_shape = (batch, dim, t, h // 2, w // 2)
+        dummy_input = torch.ones(input_shape)
+        torch_wan_resample = TorchWanResample(
+            dim=dim,
+            mode=mode
+        )
+        torch_output = torch_wan_resample(dummy_input)
+        assert torch_output.shape == expected_output_shape
+
+        wan_resample = WanResample(
+            dim,
+            mode=mode,
+            rngs=rngs
+        )
+        # channels are always last in the flax implementation
+        input_shape = (batch, t, h, w, dim)
+        dummy_input = jnp.ones(input_shape)
+        output = wan_resample(dummy_input)
+        assert output.shape == (batch, t, h // 2, w // 2, dim)
+
+        # --- Test Case 2: downsample3d ---
+        dim = 192
+        input_shape = (1, dim, 1, 240, 360)
+        torch_wan_resample = TorchWanResample(
+            dim=dim,
+            mode="downsample3d"
+        )
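+        # Minimal sketch of the missing check (assumption: mirrors Test Case 1).
+        # With feat_cache=None the temporal conv is skipped, so only the
+        # spatial stride-2 conv applies.
+        dummy_input = torch.ones(input_shape)
+        torch_output = torch_wan_resample(dummy_input)
+        assert torch_output.shape == (1, dim, 1, 120, 180)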
 
     def test_3d_conv(self):
         key = jax.random.key(0)
@@ -189,5 +318,36 @@ def test_3d_conv(self):
         output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
         assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
 
+    def test_wan_encode(self):
+        key = jax.random.key(0)
+        rngs = nnx.Rngs(key)
+        dim = 96
+        z_dim = 32
+        dim_mult = [1, 2, 4, 4]
+        num_res_blocks = 2
+        attn_scales = []
+        temperal_downsample = [False, True, True]
+        nonlinearity = "silu"
+        wan_encoder = WanEncoder3d(
+            rngs=rngs,
+            dim=dim,
+            z_dim=z_dim,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_downsample=temperal_downsample,
+            non_linearity=nonlinearity
+        )
+        batch = 1
+        channels = 3
+        t = 49
+        height = 480
+        width = 720
+        input_shape = (batch, channels, t, height, width)
+        # dummy_input instead of input, to avoid shadowing the builtin
+        dummy_input = jnp.ones(input_shape)
+        output = wan_encoder(dummy_input)
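+        # TODO (assumption): with three spatial downsamples (8x) and two
+        # temporal ones (4x, causal), the latent should come out around
+        # (1, 13, 60, 90, z_dim) in channels-last layout; add an explicit
+        # shape assert once WanEncoder3d's output layout is confirmed.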
+
+
 if __name__ == "__main__":
     absltest.main()