@@ -200,8 +200,6 @@ def __init__(
200200 precision : jax .lax .Precision = None ,
201201 attention : str = "dot_product" ,
202202 ):
203- kernel_size = _canonicalize_tuple (kernel_size , 3 , "kernel_size" )
204- stride = _canonicalize_tuple (stride , 3 , "stride" )
205203 self .conv = nnx .Conv (dim , dim , kernel_size = kernel_size , strides = stride , use_bias = True , rngs = rngs )
206204
207205 def __call__ (self , x ):
@@ -233,19 +231,19 @@ def __init__(
233231 nnx .Conv (
234232 dim ,
235233 dim // 2 ,
236- kernel_size = (1 , 3 , 3 ),
234+ kernel_size = (3 , 3 ),
237235 padding = "SAME" ,
238236 use_bias = True ,
239237 rngs = rngs ,
240238 ),
241239 )
242240 elif mode == "upsample3d" :
243241 self .resample = nnx .Sequential (
244- WanUpsample (scale_factor = (2.0 , 2.0 , 2.0 ), method = "nearest" ),
242+ WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" ),
245243 nnx .Conv (
246244 dim ,
247245 dim // 2 ,
248- kernel_size = (1 , 3 , 3 ),
246+ kernel_size = (3 , 3 ),
249247 padding = "SAME" ,
250248 use_bias = True ,
251249 rngs = rngs ,
@@ -259,11 +257,9 @@ def __init__(
259257 padding = (1 , 0 , 0 ),
260258 )
261259 elif mode == "downsample2d" :
262- # TODO - do I need to transpose?
263- self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (1 , 3 , 3 ), stride = (1 , 2 , 2 ))
260+ self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (3 , 3 ), stride = (2 , 2 ))
264261 elif mode == "downsample3d" :
265- # TODO - do I need to transpose?
266- self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (1 , 3 , 3 ), stride = (1 , 2 , 2 ))
262+ self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (3 , 3 ), stride = (2 , 2 ))
267263 self .time_conv = WanCausalConv3d (
268264 rngs = rngs , in_channels = dim , out_channels = dim , kernel_size = (3 , 1 , 1 ), stride = (2 , 1 , 1 ), padding = (0 , 0 , 0 )
269265 )
@@ -334,7 +330,6 @@ def __init__(
334330 self .norm1 = WanRMS_norm (dim = in_dim , rngs = rngs , images = False , channel_first = False )
335331 self .conv1 = WanCausalConv3d (rngs = rngs , in_channels = in_dim , out_channels = out_dim , kernel_size = 3 , padding = 1 )
336332 self .norm2 = WanRMS_norm (dim = out_dim , rngs = rngs , images = False , channel_first = False )
337- self .dropout = nnx .Dropout (dropout , rngs = rngs )
338333 self .conv2 = WanCausalConv3d (rngs = rngs , in_channels = out_dim , out_channels = out_dim , kernel_size = 3 , padding = 1 )
339334 self .conv_shortcut = (
340335 WanCausalConv3d (rngs = rngs , in_channels = in_dim , out_channels = out_dim , kernel_size = 1 )
@@ -363,7 +358,6 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=[0]):
363358
364359 x = self .norm2 (x )
365360 x = self .nonlinearity (x )
366- x = self .dropout (x )
367361
368362 if feat_cache is not None :
369363 idx = feat_idx [0 ]
@@ -384,8 +378,8 @@ class WanAttentionBlock(nnx.Module):
384378 def __init__ (self , dim : int , rngs : nnx .Rngs ):
385379 self .dim = dim
386380 self .norm = WanRMS_norm (rngs = rngs , dim = dim , channel_first = False )
387- self .to_qkv = nnx .Conv (in_features = dim , out_features = dim * 3 , kernel_size = 1 , rngs = rngs )
388- self .proj = nnx .Conv (in_features = dim , out_features = dim , kernel_size = 1 , rngs = rngs )
381+ self .to_qkv = nnx .Conv (in_features = dim , out_features = dim * 3 , kernel_size = ( 1 , 1 ) , rngs = rngs )
382+ self .proj = nnx .Conv (in_features = dim , out_features = dim , kernel_size = ( 1 , 1 ) , rngs = rngs )
389383
390384 def __call__ (self , x : jax .Array ):
391385 batch_size , time , height , width , channels = x .shape
@@ -801,8 +795,6 @@ def _encode(self, x: jax.Array):
801795 x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
802796 assert x .shape [- 1 ] == 3 , f"Expected input shape (N, D, H, W, 3), got { x .shape } "
803797
804- # self.clear_cache()
805-
806798 t = x .shape [1 ]
807799 iter_ = 1 + (t - 1 ) // 4
808800 for i in range (iter_ ):
@@ -854,8 +846,8 @@ def _decode(self, z: jax.Array, return_dict: bool = True) -> Union[FlaxDecoderOu
854846 def decode (self , z : jax .Array , return_dict : bool = True ) -> Union [FlaxDecoderOutput , jax .Array ]:
855847 if z .shape [- 1 ] != self .z_dim :
856848 # reshape channel last for JAX
857- x = jnp .transpose (x , (0 , 2 , 3 , 4 , 1 ))
858- assert x .shape [- 1 ] == self .z_dim , f"Expected input shape (N, D, H, W, { self .z_dim } , got { x .shape } "
849+ z = jnp .transpose (z , (0 , 2 , 3 , 4 , 1 ))
850+ assert z .shape [- 1 ] == self .z_dim , f"Expected input shape (N, D, H, W, { self .z_dim } , got { z .shape } "
859851 decoded = self ._decode (z ).sample
860852 if not return_dict :
861853 return (decoded ,)
0 commit comments