@@ -168,18 +168,14 @@ def __init__(self, dim: int, mode: str, rngs: nnx.Rngs, mesh: jax.sharding.Mesh
168168 self .dim = dim
169169 self .mode = mode
170170
171- # FIX: Removed pre-initialization of attributes to None to avoid NNX errors.
172-
173171 if mode == "upsample2d" :
174172 self .resample = nnx .Sequential (
175173 WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" ),
176174 nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
177175 )
178176 elif mode == "upsample3d" :
179- self .resample = nnx .Sequential (
180- WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" ),
181- nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
182- )
177+ self .upsample = WanUpsample (scale_factor = (2.0 , 2.0 ), method = "nearest" )
178+ self .conv = nnx .Conv (dim , dim // 2 , kernel_size = (3 , 3 ), padding = "SAME" , use_bias = True , rngs = rngs , kernel_init = nnx .with_partitioning (nnx .initializers .xavier_uniform (), (None , None , None , "conv_out" )), dtype = dtype , param_dtype = weights_dtype , precision = precision )
183179 self .time_conv = WanCausalConv3d (rngs = rngs , in_channels = dim , out_channels = dim * 2 , kernel_size = (3 , 1 , 1 ), padding = (1 , 0 , 0 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
184180 elif mode == "downsample2d" :
185181 self .resample = ZeroPaddedConv2D (dim = dim , rngs = rngs , kernel_size = (3 , 3 ), stride = (2 , 2 ), mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
@@ -220,7 +216,8 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
220216
221217 b , t , h , w , c = x .shape
222218 x = x .reshape (b * t , h , w , c )
223- x = self .resample (x ) # Sequential
219+ x = self .upsample (x )
220+ x = self .conv (x )
224221 h_new , w_new , c_new = x .shape [1 :]
225222 x = x .reshape (b , t , h_new , w_new , c_new )
226223
@@ -234,18 +231,19 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None) -> Tuple[jax.Arra
234231 elif self .mode == "downsample3d" :
235232 b , t , h , w , c = x .shape
236233 x = x .reshape (b * t , h , w , c )
237- x , _ = self .resample (x , None ) # ZeroPaddedConv2D
234+ x , _ = self .resample (x , None ) # Fixed: use self.resample not self.downsample_conv
238235 h_new , w_new , c_new = x .shape [1 :]
239236 x = x .reshape (b , t , h_new , w_new , c_new )
240237
241238 x , tc_cache = self .time_conv (x , cache .get ("time_conv" ))
242239 new_cache ["time_conv" ] = tc_cache
243240
244241 else :
245- if isinstance (self .resample , Identity ):
246- x , _ = self .resample (x , None )
247- else :
248- x = self .resample (x )
242+ if hasattr (self , "resample" ):
243+ if isinstance (self .resample , Identity ):
244+ x , _ = self .resample (x , None )
245+ else :
246+ x = self .resample (x )
249247
250248 return x , new_cache
251249
@@ -532,6 +530,10 @@ def __init__(self, rngs: nnx.Rngs, base_dim: int = 96, z_dim: int = 16, dim_mult
532530 self .temperal_downsample = temperal_downsample
533531 self .temporal_upsample = temperal_downsample [::- 1 ]
534532
533+ # Store latents_mean / latents_std on the instance (these attributes were previously missing).
534+ self .latents_mean = latents_mean
535+ self .latents_std = latents_std
536+
535537 self .encoder = WanEncoder3d (rngs = rngs , dim = base_dim , z_dim = z_dim * 2 , dim_mult = dim_mult , num_res_blocks = num_res_blocks , attn_scales = attn_scales , temperal_downsample = temperal_downsample , dropout = dropout , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
536538 self .quant_conv = WanCausalConv3d (rngs = rngs , in_channels = z_dim * 2 , out_channels = z_dim * 2 , kernel_size = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
537539 self .post_quant_conv = WanCausalConv3d (rngs = rngs , in_channels = z_dim , out_channels = z_dim , kernel_size = 1 , mesh = mesh , dtype = dtype , weights_dtype = weights_dtype , precision = precision )
0 commit comments